[Mlir-commits] [mlir] [mlir][shard, mpi] Adding Shard/MPI reduce_scatter (PR #184189)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Mar 3 07:04:47 PST 2026
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/184189
>From 0698503f87e0be19b78d468ef9acbe5a0652459c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 26 Feb 2026 04:42:34 -0800
Subject: [PATCH 1/8] adding mpi::reduce_scatter_block, lowering to llvm and
tests
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 34 +++++++++
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 76 ++++++++++++++++++-
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 10 +++
mlir/test/Dialect/MPI/mpiops.mlir | 6 ++
4 files changed, 125 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7e68b152fdf75..ee79026bf8214 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -333,6 +333,40 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
"(`->` type($retval)^)?";
}
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOp
+//===----------------------------------------------------------------------===//
+
+def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
+  let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, "
+                "recvcount, dtype, op, comm)`";
+  let description = [{
+    MPI_Reduce_scatter_block first performs an element-wise reduction on the
+    sendbuf across all processes in the communicator, then scatters the result
+    by distributing equal-sized blocks to each process into recvbuf.
+
+    The number of elements in `sendbuf` must be an exact multiple of the
+    number of processes in `comm`. Each process receives one block of
+    `size(sendbuf) / size(comm)` elements, so `recvbuf` must hold at least
+    that many elements. If `sendbuf` and `recvbuf` are the same value, the
+    operation is performed in place (`MPI_IN_PLACE`).
+
+    The `op` attribute specifies the reduction operation to be performed.
+    Currently only the `MPI_Op`s predefined in the standard (e.g. `MPI_SUM`)
+    are supported.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $sendbuf,
+        AnyMemRef : $recvbuf,
+        MPI_ReductionOpEnum : $op,
+        MPI_Comm : $comm
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 0dbc0a126a5c6..8571b39d41780 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -911,6 +912,79 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers `mpi.reduce_scatter_block` to a call to `MPI_Reduce_scatter_block`.
+///
+/// MPI's `recvcount` is the number of elements per block. It is derived from
+/// the *send* buffer as `size(sendbuf) / size(comm)`. Runtime `cf.assert`s
+/// check that the send buffer splits evenly and that the receive buffer can
+/// hold one block. When `sendbuf` and `recvbuf` are the same value,
+/// `MPI_IN_PLACE` is passed as the send pointer.
+struct ReduceScatterBlockOpLowering
+    : public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+    Type elemType = op.getSendbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
+
+    // If input and output are the same, request in-place operation.
+    if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+      sendPtr = LLVM::ConstantOp::create(
+          rewriter, loc, i64,
+          reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+      sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
+    }
+
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // `recvcount` is the per-rank block size: the size of the *send* buffer
+    // divided by the number of ranks. recvbuf only needs to hold one block, so
+    // dividing its size would under-count. In the in-place case both sizes
+    // come from the same buffer.
+    Value nRanks =
+        createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
+    Value recvCnt = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+
+    // The send buffer must split into equally sized blocks, one per rank.
+    Value checkSize = LLVM::MulOp::create(rewriter, loc, i32, recvCnt, nRanks);
+    Value sizeIsValid = LLVM::ICmpOp::create(
+        rewriter, loc, LLVM::ICmpPredicate::eq, checkSize, sendSize);
+    cf::AssertOp::create(
+        rewriter, loc, sizeIsValid,
+        "Input buffer's size must be a multiple of the number of ranks");
+
+    // The receive buffer must be large enough for one block.
+    Value recvIsLargeEnough = LLVM::ICmpOp::create(
+        rewriter, loc, LLVM::ICmpPredicate::uge, recvSize, recvCnt);
+    cf::AssertOp::create(rewriter, loc, recvIsLargeEnough,
+                         "Output buffer must hold at least one block");
+
+    // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
+    //      int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+              comm.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
+
+    // replace op with function call
+    auto funcCall = LLVM::CallOp::create(
+        rewriter, loc, funcDecl,
+        ValueRange{sendPtr, recvPtr, recvCnt, dataType, mpiOp, comm});
+
+    // Forward the MPI error code only if the op exposes an `!mpi.retval`.
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@@ -943,7 +1017,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
SendOpLowering, RecvOpLowering, AllGatherOpLowering,
- AllReduceOpLowering>(converter);
+ AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 9ec81c53b41f8..7e35f331b51a2 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,7 @@
// COM: Test MPICH ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
@@ -119,6 +120,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+ // CHECK: [[v73:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+ // CHECK: llvm.call @MPI_Reduce_scatter_block([[v73]], {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+ mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
@@ -132,6 +137,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
@@ -240,6 +246,10 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+ // CHECK: [[v63:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+ // CHECK: llvm.call @MPI_Reduce_scatter_block([[v63]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index 87a5647ee91d2..ee979d33c699d 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -77,6 +77,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+ // CHECK-NEXT: [[v10:%.*]] = mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ %err9 = mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+ // CHECK-NEXT: mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
+ mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
// CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
%rval = mpi.finalize : !mpi.retval
>From 9bb1b80ad4b4cb95179803d6df11e439294e7d9c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 27 Feb 2026 06:43:08 -0800
Subject: [PATCH 2/8] simplify pass (was test-simplification) and lowering
shard.reduce_scatter to MPI
---
.../include/mlir/Dialect/Shard/IR/ShardOps.td | 20 +--
.../mlir/Dialect/Shard/Transforms/Partition.h | 2 +-
.../mlir/Dialect/Shard/Transforms/Passes.td | 17 ++
.../{Simplifications.h => Simplify.h} | 16 +-
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 153 +++++++++++++++++-
mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 6 +-
.../Dialect/Shard/Transforms/CMakeLists.txt | 3 +-
.../{Simplifications.cpp => Simplify.cpp} | 108 +++++++++----
.../ShardToMPI/convert-shard-to-mpi.mlir | 34 ++++
.../Shard/all-scatter-op-lowering.mlir | 12 +-
mlir/test/Dialect/Shard/canonicalization.mlir | 8 +-
mlir/test/Dialect/Shard/folding.mlir | 2 +-
mlir/test/Dialect/Shard/invalid.mlir | 20 +--
mlir/test/Dialect/Shard/ops.mlir | 20 +--
.../{simplifications.mlir => simplify.mlir} | 85 +++++++++-
mlir/test/lib/Dialect/Shard/CMakeLists.txt | 1 -
.../Dialect/Shard/TestReshardingPartition.cpp | 2 +-
.../lib/Dialect/Shard/TestSimplifications.cpp | 47 ------
mlir/tools/mlir-opt/mlir-opt.cpp | 2 -
19 files changed, 423 insertions(+), 135 deletions(-)
rename mlir/include/mlir/Dialect/Shard/Transforms/{Simplifications.h => Simplify.h} (89%)
rename mlir/lib/Dialect/Shard/Transforms/{Simplifications.cpp => Simplify.cpp} (55%)
rename mlir/test/Dialect/Shard/{simplifications.mlir => simplify.mlir} (66%)
delete mode 100644 mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 6ef7c72d305ee..60f6d6fc1ffe4 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
`grid_axes` using the specified reduction method. The reduction is performed
element-wise across the tensor pieces from all devices in the group.
After reduction, the reduction result is scattered (split and distributed)
- across the device group along `scatter_axis`.
+ across the device group along `scatter_dim`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
- reduction = <max> scatter_axis = 0
+ reduction = <max> scatter_dim = 0
: tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
@@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
```
}];
let arguments = !con(commonArgs, (ins
- AnyNon0RankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
- IndexAttr:$scatter_axis
+ IndexAttr:$scatter_dim
));
let results = (outs
- AnyRankedTensor:$result
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
- `scatter_axis` `=` $scatter_axis
+ `scatter_dim` `=` $scatter_dim
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
@@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
let summary = "Scatter over a device grid.";
let description = [{
For each device group defined by `grid_axes`, the input tensor on the `root`
- device is split along axis `scatter_axis` and distributed across the group.
+ device is split along axis `scatter_dim` and distributed across the group.
The content of the input on all other (non-root) devices is ignored.
The `root` device is defined by its in-group multi-index.
@@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
```
shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
- scatter_axis = 0
+ scatter_dim = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
```
@@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
- IndexAttr:$scatter_axis,
+ IndexAttr:$scatter_dim,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
@@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
- `scatter_axis` `=` $scatter_axis
+ `scatter_dim` `=` $scatter_dim
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
index 37903765903db..ba12002b9f1eb 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
index bbc6a1977b13e..575c176217e61 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
@@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
];
}
+def ShardSimplify : Pass<"shard-simplify"> {
+  let summary = "Apply simplification patterns to Shard dialect operations";
+  let description = [{
+    Applies simplification patterns on the Shard dialect operations.
+    This includes:
+    - All-reduce endomorphism simplification, e.g. transforming
+      `all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`.
+    - Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops
+      share the same grid and grid_axes.
+    - Folding static grid shapes into constants.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "shard::ShardDialect"
+  ];
+}
+
def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
similarity index 89%
rename from mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
index 45ae758ec14c2..f3f4feffd8a71 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
-#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -37,8 +37,8 @@ namespace shard {
// Will not work with some op `f(x, y, z)` where only `x` and `y` form
// the algebraic structure.
template <typename AlgebraicOp>
-void populateAllReduceEndomorphismSimplificationPatterns(
- RewritePatternSet &patterns, ReductionKind reduction) {
+void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns,
+ ReductionKind reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
return &allReduceOp.getInputMutable();
@@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
// It is invalid to change ops that declare symbols during the application of
// these patterns, because symbolTableCollection is used to cache them.
-void populateSimplificationPatterns(
- RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void populateSimplifyPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTableCollection);
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection);
} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 1db14e60c5a7f..01196dd1b6fcd 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -26,7 +26,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplify.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -618,6 +618,155 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
}
};
+/// Lowers `shard.reduce_scatter` to `mpi.reduce_scatter_block`.
+///
+/// shard.reduce_scatter reduces and then scatters along a specified
+/// scatter-dim. mpi.reduce_scatter_block always scatters along the first
+/// dimension. Hence, if scatter-dim != 0, the input data is rearranged by
+/// expanding the scatter-dim into {nRanks, output_scatter_dim} and
+/// transposing nRanks to the first dimension.
+///
+/// Shape validation is done before any IR is created. It covers rank,
+/// scatter-dim range, static shapes, divisibility and the scatter factor.
+struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
+  using CommOpPattern::CommOpPattern;
+
+  LogicalResult
+  matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto gridAxes = adaptor.getGridAxes();
+    int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
+
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+
+    ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+    Value rawInput = adaptor.getInput();
+    auto inShapedType = cast<ShapedType>(rawInput.getType());
+    auto elemType = inShapedType.getElementType();
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    auto inputShape = inShapedType.getShape();
+    auto outputShape = outType.getShape();
+    int64_t rank = inShapedType.getRank();
+
+    // Validate ranks and the scatter dimension before indexing into the
+    // shapes. A rank-0 operand has no dimension to scatter along.
+    if (outType.getRank() != rank)
+      return op.emitError("Result and input must have the same rank.");
+    if (scatterDim < 0 || scatterDim >= rank)
+      return op.emitError("Scatter dimension is out of range.");
+
+    // Dynamic extents would make the scatter-factor arithmetic below
+    // meaningless, so require static shapes up front.
+    if (!inShapedType.hasStaticShape())
+      return op.emitError("Input must be statically shaped.");
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError("Result must be a statically shaped memref in "
+                          "contiguous row-major layout.");
+
+    int64_t inputDimOnAxis = inputShape[scatterDim];
+    int64_t outputDimOnAxis = outputShape[scatterDim];
+    for (int64_t i = 0; i < rank; ++i)
+      if (i != scatterDim && outputShape[i] != inputShape[i])
+        return op.emitError(
+            "Result and input shapes must match along non-scatter axes.");
+    if (outputDimOnAxis == 0)
+      return op.emitError(
+          "Output size along the scatter axis must be non-zero.");
+    if (inputDimOnAxis % outputDimOnAxis != 0)
+      return op.emitError(
+          "Input size along the scatter axis must be an exact "
+          "multiple of the output size along the scatter axis.");
+
+    int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
+
+    // Verify that nRanks matches the number of devices along the grid axes.
+    int64_t gridGroupSize =
+        collectiveProcessGroupSize(gridAxes, gridOp->getShape());
+    if (nRanks != gridGroupSize)
+      return op.emitError()
+             << "Expected the scatter factor (" << nRanks
+             << ") to match the number of devices along grid_axes ("
+             << gridGroupSize << ").";
+
+    // Get the right communicator.
+    Value comm = getComm(*gridOp, gridAxes, ib);
+
+    Value mpiInput;
+    if (scatterDim == 0) {
+      // scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
+      // Input must be contiguous for MPI.
+      Value input = getAsMemref(rawInput, ib);
+      MemRefType inType = cast<MemRefType>(input.getType());
+      if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+        return op.emitError("Input must be a statically shaped memref in "
+                            "contiguous row-major layout.");
+      mpiInput = input;
+    } else {
+      // For scatter_dim != 0 we rearrange the input so the scatter factor
+      // becomes the first dimension.
+      //
+      // 1. Get a tensor representation of the input (avoid memref->tensor
+      //    round-trip if the input is already a tensor).
+      Value tensorInput = rawInput;
+      if (!isa<RankedTensorType>(rawInput.getType())) {
+        auto inTensorType = RankedTensorType::get(inputShape, elemType);
+        tensorInput =
+            bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
+      }
+
+      // 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
+      //    {d0, ..., nRanks, o_sd, ..., dN}.
+      SmallVector<int64_t> expandedShape;
+      SmallVector<ReassociationIndices> expandReassociation;
+      int64_t expandedIdx = 0;
+      for (int64_t i = 0; i < rank; ++i) {
+        if (i == scatterDim) {
+          expandedShape.push_back(nRanks);
+          expandedShape.push_back(outputDimOnAxis);
+          expandReassociation.push_back({expandedIdx, expandedIdx + 1});
+          expandedIdx += 2;
+        } else {
+          expandedShape.push_back(inputShape[i]);
+          expandReassociation.push_back({expandedIdx});
+          expandedIdx += 1;
+        }
+      }
+      auto expandedType = RankedTensorType::get(expandedShape, elemType);
+      tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
+                                                  expandReassociation);
+
+      // 3. Transpose to move nRanks (at position scatterDim) to position 0:
+      //    {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
+      const int64_t expandedRank = static_cast<int64_t>(expandedShape.size());
+      SmallVector<int64_t> permutation, transposedShape;
+      permutation.emplace_back(scatterDim);
+      for (int64_t i = 0; i < scatterDim; ++i)
+        permutation.emplace_back(i);
+      for (int64_t i = scatterDim + 1; i < expandedRank; ++i)
+        permutation.emplace_back(i);
+      for (auto p : permutation)
+        transposedShape.emplace_back(expandedShape[p]);
+
+      Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
+      tensorInput =
+          linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
+              ->getResult(0);
+
+      // 4. Materialize as contiguous memref for MPI by copying into a
+      //    freshly allocated buffer.
+      auto mpiInType = MemRefType::get(transposedShape, elemType);
+      Value transposedBuf =
+          bufferization::ToBufferOp::create(ib, mpiInType, tensorInput);
+      mpiInput = memref::AllocOp::create(ib, mpiInType);
+      linalg::CopyOp::create(ib, transposedBuf, mpiInput);
+    }
+
+    // Allocate output buffer.
+    Value output = memref::AllocOp::create(ib, outType);
+    // Create the MPI ReduceScatter operation.
+    mpi::ReduceScatterBlockOp::create(
+        ib, TypeRange(), mpiInput, output,
+        getMPIReductionOp(adaptor.getReductionAttr()), comm);
+
+    // Deallocate the temporary input buffer if we allocated one.
+    if (scatterDim != 0)
+      memref::DeallocOp::create(ib, mpiInput);
+
+    // If the destination is a tensor, cast it to a tensor.
+    if (isa<RankedTensorType>(op.getType()))
+      output =
+          bufferization::ToTensorOp::create(ib, op.getType(), output, true);
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
using CommOpPattern::CommOpPattern;
@@ -1048,7 +1197,7 @@ struct ConvertShardToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
- ConvertAllGatherOp, ConvertAllReduceOp,
+ ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
ConvertProcessLinearIndexOp>(typeConverter, ctxt);
SymbolTableCollection stc;
populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 98234bada09e4..a173da3db1d18 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -1416,7 +1416,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
return verifyScatterOrSliceOperandAndResultShape(
- getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
+ getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(),
grid.value().getShape());
}
@@ -1445,9 +1445,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
- auto scatterAxis = getScatterAxis().getSExtValue();
+ auto scatterDim = getScatterDim().getSExtValue();
return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
- scatterAxis, getGridAxes(),
+ scatterDim, getGridAxes(),
grid.value().getShape());
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
index a884764e70e92..4e3fb6db9966c 100644
--- a/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRShardTransforms
- Simplifications.cpp
+ Simplify.cpp
ShardingPropagation.cpp
Partition.cpp
Transforms.cpp
@@ -26,4 +26,5 @@ add_mlir_dialect_library(MLIRShardTransforms
MLIRSupport
MLIRTensorDialect
MLIRTosaShardingInterfaceImpl
+ MLIRTransformUtils
)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
similarity index 55%
rename from mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index a17671e5408c4..bb12b430b1be2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -1,4 +1,4 @@
-//===- Simplifications.cpp - Shard Simplifications -_------------*- C++ -*-===//
+//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,16 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplify.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <numeric>
@@ -20,31 +23,8 @@
namespace mlir {
namespace shard {
-void populateSimplificationPatterns(
- RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
- populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
- patterns, ReductionKind::Sum);
- populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
- patterns, ReductionKind::Sum);
-
- populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
- patterns, ReductionKind::Min);
- populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
- patterns, ReductionKind::Min);
- populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
- patterns, ReductionKind::Min);
-
- populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
- patterns, ReductionKind::Max);
- populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
- patterns, ReductionKind::Max);
- populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
- patterns, ReductionKind::Max);
-
- // TODO: add simplifications for all-gather and other collectives.
-
- populateFoldingPatterns(patterns, symbolTableCollection);
-}
+#define GEN_PASS_DEF_SHARDSIMPLIFY
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
namespace {
@@ -109,12 +89,92 @@ struct GridShapeFolder
}
};
+// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the
+// same grid and grid_axes.
+//
+// AllReduceOp performs an element-wise reduction across all devices in the
+// group, and AllSliceOp then slices (scatters) the result along a tensor
+// dimension. This is exactly what ReduceScatterOp does in a single collective.
+struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ // Check if the input to AllSliceOp is produced by an AllReduceOp.
+ auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
+ if (!reduceOp)
+ return failure();
+
+ // Both ops must operate on the same grid and grid axes.
+ if (reduceOp.getGrid() != sliceOp.getGrid() ||
+ reduceOp.getGridAxes() != sliceOp.getGridAxes())
+ return failure();
+
+ // Replace with a single ReduceScatterOp.
+ rewriter.replaceOpWithNewOp<ReduceScatterOp>(
+ sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(),
+ sliceOp.getGridAxesAttr(), reduceOp.getInput(),
+ reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
+
+ return success();
+ }
+};
+
} // namespace
+/// Populates `patterns` with the shard simplifications: all-reduce
+/// endomorphism rewrites for the arith add/min/max ops, the
+/// all_slice(all_reduce) -> reduce_scatter fusion, and grid-shape folding.
+void populateSimplifyPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTableCollection) {
+ populateAllReduceEndomorphismSimplifyPatterns<arith::AddFOp>(
+ patterns, ReductionKind::Sum);
+ populateAllReduceEndomorphismSimplifyPatterns<arith::AddIOp>(
+ patterns, ReductionKind::Sum);
+
+ populateAllReduceEndomorphismSimplifyPatterns<arith::MinimumFOp>(
+ patterns, ReductionKind::Min);
+ populateAllReduceEndomorphismSimplifyPatterns<arith::MinSIOp>(
+ patterns, ReductionKind::Min);
+ populateAllReduceEndomorphismSimplifyPatterns<arith::MinUIOp>(
+ patterns, ReductionKind::Min);
+
+ populateAllReduceEndomorphismSimplifyPatterns<arith::MaximumFOp>(
+ patterns, ReductionKind::Max);
+ populateAllReduceEndomorphismSimplifyPatterns<arith::MaxSIOp>(
+ patterns, ReductionKind::Max);
+ populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
+ patterns, ReductionKind::Max);
+
+ patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+
+ // TODO: add simplify patterns for all-gather and other collectives.
+
+ populateFoldingPatterns(patterns, symbolTableCollection);
+}
+
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection) {
patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
}
+namespace {
+
+/// Greedily applies the shard simplification patterns to the pass's root
+/// operation; signals pass failure when the greedy driver reports failure
+/// (e.g. the rewrite did not converge).
+struct ShardSimplifyPass : public impl::ShardSimplifyBase<ShardSimplifyPass> {
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ SymbolTableCollection symbolTableCollection;
+ populateSimplifyPatterns(patterns, symbolTableCollection);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
} // namespace shard
} // namespace mlir
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index f3da09d05e3b8..08c3897e4e650 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -159,6 +159,40 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
return %0 : memref<3x4xf64>
}
+ // CHECK-LABEL: func.func @reduce_scatter_memref(
+ func.func @reduce_scatter_memref(
+ // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
+ %arg0 : memref<3x4xf32>) -> memref<1x4xf32> {
+ // CHECK-DAG: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-DAG: [[vc7_i32:%.*]] = arith.constant 7 : i32
+ // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc7_i32]], [[vc0_i32]]) : !mpi.comm
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<1x4xf32>
+ // CHECK: mpi.reduce_scatter_block([[varg0]], [[valloc]], MPI_SUM, [[vnewcomm]]) : memref<3x4xf32>, memref<1x4xf32>
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : memref<3x4xf32> -> memref<1x4xf32>
+ // CHECK: return [[valloc]] : memref<1x4xf32>
+ return %0 : memref<1x4xf32>
+ }
+
+ // CHECK-LABEL: func.func @reduce_scatter_tensor_dim1(
+ func.func @reduce_scatter_tensor_dim1(
+ // CHECK-SAME: [[varg0:%.*]]: tensor<2x12xf32>
+ %arg0 : tensor<2x12xf32>) -> tensor<2x4xf32> {
+ // CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32>
+ // CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32>
+ // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2]
+ // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32>
+ // CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32>
+ // CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
+ // CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM,
+ // CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32>
+ // CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32>
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32>
+ // CHECK: return [[vout]] : tensor<2x4xf32>
+ return %0 : tensor<2x4xf32>
+ }
+
// CHECK-LABEL: func @allgather_tensor_0
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
diff --git a/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
index bc911215851aa..5b11078c32bbf 100644
--- a/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --shard-simplify --cse %s | FileCheck %s
shard.grid @grid_1d(shape = ?)
@@ -59,15 +59,15 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
// CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
// CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
// CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
- // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
- // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+ // CHECK: %[[SCATTER_DIM_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
+ // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_DIM_SIZE]], %[[PROC_GROUP_SIZE]] : index
// CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
// CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+ // CHECK: %[[RESULT_SCATTER_DIM_SIZE:.*]] = arith.divui %[[SCATTER_DIM_SIZE]], %[[PROC_GROUP_SIZE]] : index
// CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
// CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
- // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
- // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
+ // CHECK: %[[SCATTER_DIM_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_DIM_SIZE]] : index
+ // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_DIM_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_DIM_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
%0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
// CHECK: return %[[RESULT]] : tensor<?x?xf16>
return %0 : tensor<?x?xf16>
diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir
index ed40dfb7237da..a3a9c592ff0ae 100644
--- a/mlir/test/Dialect/Shard/canonicalization.mlir
+++ b/mlir/test/Dialect/Shard/canonicalization.mlir
@@ -135,7 +135,7 @@ func.func @reduce_scatter_empty_grid_axes(
// CHECK-NOT: shard.reduce_scatter
%0 = shard.reduce_scatter %arg0 on @grid0
grid_axes = []
- scatter_axis = 0
+ scatter_dim = 0
: tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
@@ -148,7 +148,7 @@ func.func @reduce_scatter_empty_grid_axes_different_return_type(
%0 = shard.reduce_scatter %arg0 on @grid0
// CHECK-NOT: grid_axes
grid_axes = []
- scatter_axis = 0
+ scatter_dim = 0
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
@@ -160,7 +160,7 @@ func.func @reduce_scatter_default_reduction(
grid_axes = [0]
// CHECK-NOT: reduction
reduction = sum
- scatter_axis = 0
+ scatter_dim = 0
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
@@ -172,7 +172,7 @@ func.func @scatter_empty_grid_axes(
// CHECK-NOT: shard.scatter
%0 = shard.scatter %arg0 on @grid0
grid_axes = []
- scatter_axis = 0
+ scatter_dim = 0
root = []
: (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir
index 5a0f35b53a129..7b6ba33f84fd9 100644
--- a/mlir/test/Dialect/Shard/folding.mlir
+++ b/mlir/test/Dialect/Shard/folding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+// RUN: mlir-opt -shard-simplify %s | FileCheck %s
shard.grid @grid0(shape = 4x?x2)
shard.grid @grid1(shape = 2x3)
diff --git a/mlir/test/Dialect/Shard/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir
index 6acac971164ed..c92932a725d1c 100644
--- a/mlir/test/Dialect/Shard/invalid.mlir
+++ b/mlir/test/Dialect/Shard/invalid.mlir
@@ -707,7 +707,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Grid axes contains duplicate elements.}}
- %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_dim = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@@ -719,7 +719,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 scatter_dim = 0
: tensor<?xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
@@ -731,7 +731,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
: tensor<3xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
@@ -743,7 +743,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@@ -756,7 +756,7 @@ func.func @scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Grid axes contains duplicate elements.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
- scatter_axis = 0 root = [0, 0]
+ scatter_dim = 0 root = [0, 0]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -769,7 +769,7 @@ func.func @scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
%0 = shard.scatter %arg0 on @grid0
- scatter_axis = 0 root = []
+ scatter_dim = 0 root = []
: (tensor<?xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -782,7 +782,7 @@ func.func @scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
- scatter_axis = 0 root = [1]
+ scatter_dim = 0 root = [1]
: (tensor<3xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -795,7 +795,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
- scatter_axis = 0 root = [1]
+ scatter_dim = 0 root = [1]
: (tensor<4xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -808,7 +808,7 @@ func.func @scatter_root_dimension_out_of_bounds(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
- scatter_axis = 0 root = [3]
+ scatter_dim = 0 root = [3]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -821,7 +821,7 @@ func.func @scatter_root_wrong_number_dimensions(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
- scatter_axis = 0 root = [2, 2]
+ scatter_dim = 0 root = [2, 2]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
diff --git a/mlir/test/Dialect/Shard/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir
index 5265dadd2a845..5d9ea064bed44 100644
--- a/mlir/test/Dialect/Shard/ops.mlir
+++ b/mlir/test/Dialect/Shard/ops.mlir
@@ -453,10 +453,10 @@ func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
// CHECK-NEXT: shard.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
+ // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_dim = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
- reduction = max scatter_axis = 1
+ reduction = max scatter_dim = 1
: tensor<3x4xf32> -> tensor<3x1xf64>
return %0 : tensor<3x1xf64>
}
@@ -466,9 +466,9 @@ func.func @reduce_scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
// CHECK-NEXT: shard.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_dim = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
- %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_dim = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@@ -479,10 +479,10 @@ func.func @scatter_static_dimensions(
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
// CHECK-NEXT: shard.scatter %[[ARG]]
// CHECK-SAME: on @grid0 grid_axes = [2]
- // CHECK-SAME: scatter_axis = 1 root = [1]
+ // CHECK-SAME: scatter_dim = 1 root = [1]
// CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
%0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
- scatter_axis = 1 root = [1]
+ scatter_dim = 1 root = [1]
: (tensor<3x4xf32>) -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
}
@@ -493,10 +493,10 @@ func.func @scatter_dynamic_dimensions(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NEXT: shard.scatter %[[ARG]]
// CHECK-SAME: on @grid3 grid_axes = [0, 1]
- // CHECK-SAME: scatter_axis = 0 root = [1, 2]
+ // CHECK-SAME: scatter_dim = 0 root = [1, 2]
// CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
%0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
- scatter_axis = 0 root = [1, 2]
+ scatter_dim = 0 root = [1, 2]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -510,11 +510,11 @@ func.func @scatter_dynamic_root(
) -> tensor<1xi8> {
// CHECK-NEXT: shard.scatter %[[ARG0]]
// CHECK-SAME: on @grid0 grid_axes = [0, 2]
- // CHECK-SAME: scatter_axis = 0
+ // CHECK-SAME: scatter_dim = 0
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
- scatter_axis = 0
+ scatter_dim = 0
root = [1, %arg1]
: (tensor<8xi8>, index) -> tensor<1xi8>
return %0 : tensor<1xi8>
diff --git a/mlir/test/Dialect/Shard/simplifications.mlir b/mlir/test/Dialect/Shard/simplify.mlir
similarity index 66%
rename from mlir/test/Dialect/Shard/simplifications.mlir
rename to mlir/test/Dialect/Shard/simplify.mlir
index 33cd490be744a..e5693a288fda6 100644
--- a/mlir/test/Dialect/Shard/simplifications.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+// RUN: mlir-opt -shard-simplify %s | FileCheck %s
shard.grid @grid0(shape = 4x2)
shard.grid @grid1(shape = 4)
@@ -177,3 +177,86 @@ func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
%0 = arith.maxsi %extracted, %c1_i64 : i64
return %0 : i64
}
+
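+//===----------------------------------------------------------------------===//
+// AllReduceOp + AllSliceOp -> ReduceScatterOp tests
+//===----------------------------------------------------------------------===//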
+
+// Basic case: all_slice(all_reduce(x)) with matching grid and axes folds
+// into reduce_scatter.
+// CHECK-LABEL: func.func @all_reduce_all_slice_to_reduce_scatter
+func.func @all_reduce_all_slice_to_reduce_scatter(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
+ // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: return %[[RS]]
+ return %1 : tensor<1x8xf32>
+}
+
+// Verify non-default reduction kind is preserved.
+// CHECK-LABEL: func.func @all_reduce_all_slice_max_reduction
+func.func @all_reduce_all_slice_max_reduction(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : tensor<4x8xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] reduction = max scatter_dim = 0
+ // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: return %[[RS]]
+ return %1 : tensor<1x8xf32>
+}
+
+// Slice on a different tensor axis than the reduce axes.
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_slice_axis
+func.func @all_reduce_all_slice_different_slice_axis(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1] : tensor<4x8xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 1
+ // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK: return %[[RS]]
+ return %1 : tensor<4x4xf32>
+}
+
+// Do not fold when grids differ.
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid
+func.func @all_reduce_all_slice_different_grid(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
+ // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
+ // CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid1
+ %1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: return %[[AS]]
+ return %1 : tensor<1x8xf32>
+}
+
+// Do not fold when grid_axes differ.
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid_axes
+func.func @all_reduce_all_slice_different_grid_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
+ // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
+ // CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid0 grid_axes = [1]
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK: return %[[AS]]
+ return %1 : tensor<4x4xf32>
+}
+
+// Verify element type conversion is preserved (all_reduce input/output types may differ).
+// CHECK-LABEL: func.func @all_reduce_all_slice_type_promotion
+func.func @all_reduce_all_slice_type_promotion(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<1x8xf64> {
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf64>
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf64> -> tensor<1x8xf64>
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
+ // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf64>
+ // CHECK: return %[[RS]]
+ return %1 : tensor<1x8xf64>
+}
diff --git a/mlir/test/lib/Dialect/Shard/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
index f91c54721e030..a97839b6b1ffd 100644
--- a/mlir/test/lib/Dialect/Shard/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
@@ -2,7 +2,6 @@
add_mlir_library(MLIRShardTest
TestOpLowering.cpp
TestReshardingPartition.cpp
- TestSimplifications.cpp
EXCLUDE_FROM_LIBMLIR
)
diff --git a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
index 23fdad1bd624d..1d1812e4aea39 100644
--- a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
@@ -1,4 +1,4 @@
-//===- TestSimplification.cpp - Test simplification -----------------------===//
+//===- TestReshardingPartition.cpp - Test resharding partition ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
deleted file mode 100644
index 28852153f37f6..0000000000000
--- a/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
+++ /dev/null
@@ -1,47 +0,0 @@
-//===- TestSimplification.cpp - Test simplification -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Shard/IR/ShardDialect.h"
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestShardSimplificationsPass
- : public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
-
- void runOnOperation() override;
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, shard::ShardDialect>();
- }
- StringRef getArgument() const final { return "test-grid-simplifications"; }
- StringRef getDescription() const final { return "Test grid simplifications"; }
-};
-} // namespace
-
-void TestShardSimplificationsPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- SymbolTableCollection symbolTableCollection;
- shard::populateSimplificationPatterns(patterns, symbolTableCollection);
- [[maybe_unused]] LogicalResult status =
- applyPatternsGreedily(getOperation(), std::move(patterns));
- assert(succeeded(status) && "Rewrite patters application did not converge.");
-}
-
-namespace mlir {
-namespace test {
-void registerTestShardSimplificationsPass() {
- PassRegistration<TestShardSimplificationsPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a427132247e6d..564bb700b53e3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -131,7 +131,6 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMemRefToLLVMWithTransforms();
void registerTestReshardingPartitionPass();
-void registerTestShardSimplificationsPass();
void registerTestMultiBuffering();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
@@ -281,7 +280,6 @@ static void registerTestPasses() {
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMemRefToLLVMWithTransforms();
mlir::test::registerTestReshardingPartitionPass();
- mlir::test::registerTestShardSimplificationsPass();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestNVGPULowerings();
>From 6767033a1b1a3dd00fa81a4a8116cd420f035532 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 08:56:09 -0800
Subject: [PATCH 3/8] using correct recvcount for lowering mpi.reduce_scatter
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 ++--
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 19 ++++++++--------
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 5 +++--
.../lib/Dialect/Shard/Transforms/Simplify.cpp | 2 +-
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 22 +++++++++++++++----
5 files changed, 33 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index ee79026bf8214..0ed2a98545870 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -354,8 +354,8 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
}];
let arguments = (
- ins AnyMemRef : $sendbuf,
- AnyMemRef : $recvbuf,
+ ins AnyNon0RankedMemRef : $sendbuf,
+ AnyNon0RankedMemRef : $recvbuf,
MPI_ReductionOpEnum : $op,
MPI_Comm : $comm
);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 8571b39d41780..5087b96adf8ec 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -952,15 +952,16 @@ struct ReduceScatterBlockOpLowering
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- // recvCnt should be the size of the block -> we need to divide by the
- // number of ranks to get the count argument for the function call.
- Value nRanks = createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
- Value recvCnt = LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
-
- // Check that recvSize is a multiple of the number of ranks
- Value checkSize = LLVM::MulOp::create(rewriter, loc, i32, recvCnt, nRanks);
- Value sizeIsValid = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, checkSize, recvSize);
- cf::AssertOp::create(rewriter, loc, sizeIsValid, "Output buffer's size must be a multiple of the number of ranks");
+ // Compare sendSize against recvSize * nRanks (not sendSize / nRanks against
+ // recvSize) so a send size that is not an exact multiple is rejected.
+ Value nRanks =
+ createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
+ Value expected = LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
+ Value sizeIsValid = LLVM::ICmpOp::create(
+ rewriter, loc, LLVM::ICmpPredicate::eq, expected, sendSize);
+ cf::AssertOp::create(rewriter, loc, sizeIsValid,
+ "Send buffer's size must be the receive buffer's size "
+ "times the number of ranks");
// 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
// int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
@@ -974,7 +973,7 @@ struct ReduceScatterBlockOpLowering
// replace op with function call
auto funcCall = LLVM::CallOp::create(
rewriter, loc, funcDecl,
- ValueRange{sendPtr, recvPtr, recvCnt, dataType, mpiOp, comm});
+ ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 01196dd1b6fcd..68448fcbd4427 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -649,7 +649,8 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
int64_t outputDimOnAxis = outputShape[scatterDim];
for (size_t i = 0; i < outputShape.size(); ++i)
- if (outputShape[i] != inputShape[i] && i != (size_t)scatterDim)
+ if (outputShape[i] != inputShape[i] &&
+ i != static_cast<size_t>(scatterDim))
return op.emitError(
"Result and input shapes must match along non-scatter axes.");
if (outputDimOnAxis == 0)
@@ -706,7 +707,7 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
SmallVector<int64_t> expandedShape;
SmallVector<ReassociationIndices> expandReassociation;
int64_t expandedIdx = 0;
- for (int64_t i = 0; i < (int64_t)inputShape.size(); ++i) {
+ for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
if (i == scatterDim) {
expandedShape.push_back(nRanks);
expandedShape.push_back(outputDimOnAxis);
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index bb12b430b1be2..c4256b9d3d80a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -102,7 +102,7 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
PatternRewriter &rewriter) const override {
// Check if the input to AllSliceOp is produced by an AllReduceOp.
auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
- if (!reduceOp)
+ if (!reduceOp || !reduceOp->hasOneUse())
return failure();
// Both ops must operate on the same grid and grid axes.
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 7e35f331b51a2..7069db44a3f58 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -120,8 +120,15 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
- // CHECK: [[v73:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
- // CHECK: llvm.call @MPI_Reduce_scatter_block([[v73]], {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+ // CHECK: llvm.mul
+ // CHECK: [[rsb_cst:%.*]] = llvm.mlir.constant(1 : index) : i32
+ // CHECK: [[rsb_dim:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[rsb_trunc:%.*]] = llvm.trunc [[rsb_dim]] : i64 to i32
+ // CHECK: [[rsb_recvcount:%.*]] = llvm.mul [[rsb_trunc]], [[rsb_cst]] : i32
+ // CHECK: [[rsb_ipp:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+ // CHECK: llvm.cond_br {{.*}}, ^[[rsb_bb1:.*]], ^{{.*}}
+ // CHECK: ^[[rsb_bb1]]:
+ // CHECK: llvm.call @MPI_Reduce_scatter_block([[rsb_ipp]], {{.*}}, [[rsb_recvcount]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -246,8 +253,15 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
- // CHECK: [[v63:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
- // CHECK: llvm.call @MPI_Reduce_scatter_block([[v63]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: llvm.mul
+ // CHECK: [[rsb_cst:%.*]] = llvm.mlir.constant(1 : index) : i32
+ // CHECK: [[rsb_dim:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[rsb_trunc:%.*]] = llvm.trunc [[rsb_dim]] : i64 to i32
+ // CHECK: [[rsb_recvcount:%.*]] = llvm.mul [[rsb_trunc]], [[rsb_cst]] : i32
+ // CHECK: [[rsb_ipp:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+ // CHECK: llvm.cond_br {{.*}}, ^[[rsb_bb1:.*]], ^{{.*}}
+ // CHECK: ^[[rsb_bb1]]:
+ // CHECK: llvm.call @MPI_Reduce_scatter_block([[rsb_ipp]], {{.*}}, [[rsb_recvcount]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
>From 06803b0b4eab0ff6f9391269961f758e039e7f06 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 10:19:00 -0800
Subject: [PATCH 4/8] adding performance hints
---
mlir/lib/Dialect/Shard/Transforms/Simplify.cpp | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index c4256b9d3d80a..525ff007bc2f6 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -95,6 +95,17 @@ struct GridShapeFolder
// AllReduceOp performs an element-wise reduction across all devices in the
// group, and AllSliceOp then slices (scatters) the result along a tensor
// dimension. This is exactly what ReduceScatterOp does in a single collective.
+//
+// With a ring algorithm over N ranks and M elements:
+// AllReduce: 2*(N-1) steps of M/N each => ~2M elements sent per rank
+// AllSlice: local slice, no communication
+// ReduceScatter: (N-1) steps of M/N each => ~M elements sent per rank
+// So this fusion roughly halves the communication volume.
+//
+// Memory-wise, AllReduce produces a full-sized M-element result that the
+// subsequent AllSlice must keep alive until the slice is taken. ReduceScatter
+// only materializes the M/N-element local slice, so the result buffer is a
+// factor of N smaller (the M-element input buffer is needed in both cases).
struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
using OpRewritePattern::OpRewritePattern;
>From 130cf4178b82e0f2257611920f8466289187baa4 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 19:22:32 +0100
Subject: [PATCH 5/8] Update mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5087b96adf8ec..50817cf5e00dd 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -954,9 +954,10 @@ struct ReduceScatterBlockOpLowering
Value nRanks =
createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
- Value expected = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+ Value totalExpected =
+ LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
Value sizeIsValid = LLVM::ICmpOp::create(
- rewriter, loc, LLVM::ICmpPredicate::eq, expected, recvSize);
+ rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected);
cf::AssertOp::create(rewriter, loc, sizeIsValid,
"Send buffer's size must be the receive buffer's size "
"times the number of ranks");
>From 77b42fbfa7b67e9987cb2a3f3aad650e7aac3e46 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 19:24:58 +0100
Subject: [PATCH 6/8] Update mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 68448fcbd4427..830a9333ade4a 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -641,8 +641,8 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
Value rawInput = adaptor.getInput();
auto inShapedType = cast<ShapedType>(rawInput.getType());
- auto elemType = inShapedType.getElementType();
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+ auto elemType = outType.getElementType();
auto inputShape = inShapedType.getShape();
auto outputShape = outType.getShape();
int64_t inputDimOnAxis = inputShape[scatterDim];
>From b7d9e6a4848d6944d38982e751306e7f12daae19 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Mar 2026 02:49:28 -0800
Subject: [PATCH 7/8] verify same elementtype
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 1 +
mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 15 +++++++++++++++
2 files changed, 16 insertions(+)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 0ed2a98545870..fb0192ba748ad 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -365,6 +365,7 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
"(`->` type($retval)^)?";
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 6cca853071dc2..e5e09e28998ba 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -15,8 +15,23 @@
using namespace mlir;
using namespace mlir::mpi;
+//===----------------------------------------------------------------------===//
+// Verifiers
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() {
+ if (getSendbuf().getType().getElementType() !=
+ getRecvbuf().getType().getElementType())
+ return emitOpError("sendbuf and recvbuf must have the same element type");
+ return success();
+}
+
namespace {
+//===----------------------------------------------------------------------===//
+// Canonicalization patterns
+//===----------------------------------------------------------------------===//
+
// If input memref has dynamic shape and is a cast and if the cast's input has
// static shape, fold the cast's static input into the given operation.
template <typename OpT>
>From 7dfe9b283965b398d0fef4c6357d00115d444a18 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Mar 2026 07:04:29 -0800
Subject: [PATCH 8/8] separating mpitollvm tests
---
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 493 +++++++++---------
1 file changed, 261 insertions(+), 232 deletions(-)
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 7069db44a3f58..73ad2d8f9299f 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -1,139 +1,151 @@
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
// COM: Test MPICH ABI
-// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
-// CHECK: llvm.func @MPI_Finalize() -> i32
-// CHECK: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK-DAG: llvm.func @MPI_Finalize() -> i32
+// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+// CHECK-DAG: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+// CHECK-DAG: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+// CHECK-DAG: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
- // CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
- func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
-
- // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
- // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK-LABEL: llvm.func @test_init_finalize_mpich
+ func.func @test_init_finalize_mpich() {
+ // CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr
+ // CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
+ %1 = mpi.finalize : !mpi.retval
+ return
+ }
- // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ // CHECK-LABEL: llvm.func @test_comm_rank_mpich
+ func.func @test_comm_rank_mpich() {
+ // CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
-
- // CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
- // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
+ // CHECK: [[v1:%.*]] = llvm.trunc [[v0]] : i64 to i32
+ // CHECK: [[v2:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v3:%.*]] = llvm.alloca [[v2]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: llvm.call @MPI_Comm_rank([[v1]], [[v3]]) : (i32, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
+ return
+ }
- // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
- // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
- // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
- // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK-LABEL: llvm.func @test_send_mpich
+ func.func @test_send_mpich(%arg0: memref<100xf32>) {
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
+ // CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: llvm.call @MPI_Comm_rank
+ // CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
+ // COM: Test send without retval
+ // CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32
+ // CHECK: [[v8:%.*]] = llvm.mul [[v7]]
+ // CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32
+ // CHECK: = llvm.call @MPI_Send([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
-
- // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
- // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
- // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // COM: Test send with retval
+ // CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+ return
+ }
- // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
- // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
- // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
- // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
- // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK-LABEL: llvm.func @test_recv_mpich
+ func.func @test_recv_mpich(%arg0: memref<100xf32>) {
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
+ // CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: llvm.call @MPI_Comm_rank
+ // CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
+ // COM: Test recv without retval
+ // CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32
+ // CHECK: [[v8:%.*]] = llvm.mul [[v7]]
+ // CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32
+ // CHECK: [[v11:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: [[v12:%.*]] = llvm.inttoptr [[v11]] : i64 to !llvm.ptr
+ // CHECK: = llvm.call @MPI_Recv([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]], [[v12]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
-
- // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
- // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
- // CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
- // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
- // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // COM: Test recv with retval
+ // CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
-
- // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
+ return
+ }
+
+ // CHECK-LABEL: llvm.func @test_comm_split_mpich
+ func.func @test_comm_split_mpich() {
+ // CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v1:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
- // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
+ // CHECK: [[v2:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
- // CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
- // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
- // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
+ // CHECK: [[v3:%.*]] = llvm.trunc [[v0]] : i64 to i32
+ // CHECK: [[v4:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v5:%.*]] = llvm.alloca [[v4]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: llvm.call @MPI_Comm_split([[v3]], [[v1]], [[v2]], [[v5]]) : (i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: llvm.load [[v5]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+ return
+ }
+ // CHECK-LABEL: llvm.func @test_allgather_mpich
+ func.func @test_allgather_mpich(%arg0: memref<100xf32>) {
+ %comm = mpi.comm_world : !mpi.comm
// CHECK: llvm.call @MPI_Comm_size
- // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
- %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ // CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+ %err = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ return
+ }
- // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v63a:%.*]] = llvm.trunc [[v62]] : i64 to i32
- // CHECK: [[v63:%.*]] = llvm.mul [[v63a]]
- // CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v68a:%.*]] = llvm.trunc [[v67]] : i64 to i32
- // CHECK: [[v68:%.*]] = llvm.mul [[v68a]]
- // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
- // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
- // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
- // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+ // CHECK-LABEL: llvm.func @test_allreduce_mpich
+ func.func @test_allreduce_mpich(%arg0: memref<100xf32>) {
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
+ // CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ // CHECK: [[v2:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v3:%.*]] = llvm.mul
+ // CHECK: [[v4:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v5:%.*]] = llvm.mlir.constant(-1 : i64) : i64
+ // CHECK: [[v6:%.*]] = llvm.inttoptr [[v5]] : i64 to !llvm.ptr
+ // CHECK: [[v7:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[v8:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+ // CHECK: [[v9:%.*]] = llvm.trunc [[v1]] : i64 to i32
+ // CHECK: llvm.call @MPI_Allreduce([[v6]], [[v4]], [[v3]], [[v7]], [[v8]], [[v9]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+ return
+ }
+ // CHECK-LABEL: llvm.func @test_reduce_scatter_block_mpich
+ func.func @test_reduce_scatter_block_mpich(%arg0: memref<100xf32>) {
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: llvm.mul
- // CHECK: [[rsb_cst:%.*]] = llvm.mlir.constant(1 : index) : i32
- // CHECK: [[rsb_dim:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[rsb_trunc:%.*]] = llvm.trunc [[rsb_dim]] : i64 to i32
- // CHECK: [[rsb_recvcount:%.*]] = llvm.mul [[rsb_trunc]], [[rsb_cst]] : i32
- // CHECK: [[rsb_ipp:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
- // CHECK: llvm.cond_br {{.*}}, ^[[rsb_bb1:.*]], ^{{.*}}
- // CHECK: ^[[rsb_bb1]]:
- // CHECK: llvm.call @MPI_Reduce_scatter_block([[rsb_ipp]], {{.*}}, [[rsb_recvcount]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+ // CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32
+ // CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32
+ // CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32
+ // CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+ // CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}}
+ // CHECK: ^[[bb1]]:
+ // CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
-
- // CHECK: llvm.call @MPI_Finalize() : () -> i32
- %3 = mpi.finalize : !mpi.retval
-
return
}
}
@@ -141,143 +153,160 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// -----
// COM: Test OpenMPI ABI
-// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
-// CHECK: llvm.func @MPI_Finalize() -> i32
-// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
-// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
-// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
-// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
-// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
+// CHECK-DAG: llvm.func @MPI_Finalize() -> i32
+// CHECK-DAG: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
+// CHECK-DAG: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
+// CHECK-DAG: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-DAG: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
+// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
- // CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
- func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
-
- // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
- // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK-LABEL: llvm.func @test_init_finalize_openmpi
+ func.func @test_init_finalize_openmpi() {
+ // CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr
+ // CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
+ %1 = mpi.finalize : !mpi.retval
+ return
+ }
+ // CHECK-LABEL: llvm.func @test_comm_rank_openmpi
+ func.func @test_comm_rank_openmpi() {
%comm = mpi.comm_world : !mpi.comm
- // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
- // CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
- // CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
- // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64
+ // CHECK: [[v2:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr
+ // CHECK: [[v3:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v4:%.*]] = llvm.alloca [[v3]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: llvm.call @MPI_Comm_rank([[v2]], [[v4]]) : (!llvm.ptr, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
+ return
+ }
- // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
- // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
- // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
- // CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+ // CHECK-LABEL: llvm.func @test_send_openmpi
+ func.func @test_send_openmpi(%arg0: memref<100xf32>) {
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
+ // CHECK: llvm.call @MPI_Comm_rank
+ // CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
+ // COM: Test send without retval
+ // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32
+ // CHECK: [[v9:%.*]] = llvm.mul [[v8]]
+ // CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
+ // CHECK: = llvm.call @MPI_Send([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
-
- // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
- // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
- // CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+ // COM: Test send with retval
+ // CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+ return
+ }
- // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
- // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
- // CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
- // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK-LABEL: llvm.func @test_recv_openmpi
+ func.func @test_recv_openmpi(%arg0: memref<100xf32>) {
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
+ // CHECK: llvm.call @MPI_Comm_rank
+ // CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
+ // COM: Test recv without retval
+ // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32
+ // CHECK: [[v9:%.*]] = llvm.mul [[v8]]
+ // CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
+ // CHECK: [[v12:%.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: [[v13:%.*]] = llvm.inttoptr [[v12]] : i64 to !llvm.ptr
+ // CHECK: = llvm.call @MPI_Recv([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]], [[v13]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
-
- // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
- // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
- // CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
- // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+ // COM: Test recv with retval
+ // CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
-
+ return
+ }
+
+ // CHECK-LABEL: llvm.func @test_comm_split_openmpi
+ func.func @test_comm_split_openmpi() {
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64
+ // CHECK: [[v2:%.*]] = llvm.mlir.constant(10 : i32) : i32
+ %color = arith.constant 10 : i32
+ // CHECK: [[v3:%.*]] = llvm.mlir.constant(22 : i32) : i32
+ %key = arith.constant 22 : i32
+ // CHECK: [[v4:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr
+ // CHECK: [[v5:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v6:%.*]] = llvm.alloca [[v5]] x !llvm.ptr : (i32) -> !llvm.ptr
+ // CHECK: llvm.call @MPI_Comm_split([[v4]], [[v2]], [[v3]], [[v6]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+ // CHECK: llvm.load [[v6]] : !llvm.ptr -> i32
+ %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+ return
+ }
+
+ // CHECK-LABEL: llvm.func @test_allgather_openmpi
+ func.func @test_allgather_openmpi(%arg0: memref<100xf32>) {
+ %comm = mpi.comm_world : !mpi.comm
// CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.udiv {{.*}} : i32
- // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
+ return
+ }
- // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v53a:%.*]] = llvm.trunc [[v52]] : i64 to i32
- // CHECK: [[v53:%.*]] = llvm.mul [[v53a]]
- // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v58a:%.*]] = llvm.trunc [[v57]] : i64 to i32
- // CHECK: [[v58:%.*]] = llvm.mul [[v58a]]
- // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
- // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
- // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
- // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK-LABEL: llvm.func @test_allreduce_openmpi
+ func.func @test_allreduce_openmpi(%arg0: memref<100xf32>) {
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
+ // CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
+ // CHECK: [[v3:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v4:%.*]] = llvm.mul
+ // CHECK: [[v5:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v6:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: [[v7:%.*]] = llvm.inttoptr [[v6]] : i64 to !llvm.ptr
+ // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK: [[v9:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+ // CHECK: [[v10:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
+ // CHECK: llvm.call @MPI_Allreduce([[v7]], [[v5]], [[v4]], [[v8]], [[v9]], [[v10]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+ return
+ }
+ // CHECK-LABEL: llvm.func @test_reduce_scatter_block_openmpi
+ func.func @test_reduce_scatter_block_openmpi(%arg0: memref<100xf32>) {
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: llvm.mul
- // CHECK: [[rsb_cst:%.*]] = llvm.mlir.constant(1 : index) : i32
- // CHECK: [[rsb_dim:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[rsb_trunc:%.*]] = llvm.trunc [[rsb_dim]] : i64 to i32
- // CHECK: [[rsb_recvcount:%.*]] = llvm.mul [[rsb_trunc]], [[rsb_cst]] : i32
- // CHECK: [[rsb_ipp:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
- // CHECK: llvm.cond_br {{.*}}, ^[[rsb_bb1:.*]], ^{{.*}}
- // CHECK: ^[[rsb_bb1]]:
- // CHECK: llvm.call @MPI_Reduce_scatter_block([[rsb_ipp]], {{.*}}, [[rsb_recvcount]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32
+ // CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32
+ // CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32
+ // CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+ // CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}}
+ // CHECK: ^[[bb1]]:
+ // CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
-
- // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
- %color = arith.constant 10 : i32
- // CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
- %key = arith.constant 22 : i32
- // CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr
- // CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
- // CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32
- %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
-
- // CHECK: llvm.call @MPI_Finalize() : () -> i32
- %3 = mpi.finalize : !mpi.retval
-
return
}
}
@@ -285,13 +314,13 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// -----
module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = 4, "MPI:comm_world_rank" = 1> } {
- // CHECK: llvm.func @mpi_test_fold
- func.func @mpi_test_fold(%arg0: memref<100xf32>) {
- // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ // CHECK-LABEL: llvm.func @test_fold
+ func.func @test_fold(%arg0: memref<100xf32>) {
+ // CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK-NOT: llvm.call @MPI_Comm_size
- // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+ // CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
%err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
return
}
More information about the Mlir-commits
mailing list