[Mlir-commits] [mlir] b9ab888 - [mlir][shard, mpi] Lowering shard.allgather to MPI (#177202)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 02:02:08 PST 2026
Author: Frank Schlimbach
Date: 2026-01-23T11:02:02+01:00
New Revision: b9ab8885c89b80cdb638aecbd5114672ec4fdb4b
URL: https://github.com/llvm/llvm-project/commit/b9ab8885c89b80cdb638aecbd5114672ec4fdb4b
DIFF: https://github.com/llvm/llvm-project/commit/b9ab8885c89b80cdb638aecbd5114672ec4fdb4b.diff
LOG: [mlir][shard,mpi] Lowering shard.allgather to MPI (#177202)
- Lower `shard.all_gather` to `mpi.allgather`
- Fix the lowering of `shard.all_reduce`
- Minor refactoring
Added:
Modified:
mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 5e68f75ee08bf..6ef7c72d305ee 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -530,11 +530,11 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
```
}];
let arguments = !con(commonArgs, (ins
- AnyNon0RankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
IndexAttr:$gather_axis
));
let results = (outs
- AnyNon0RankedTensor:$result
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
index 57d65e687ea35..1ddd1985389bc 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -39,14 +39,14 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
ImplicitLocOpBuilder &builder);
// Get process linear index along the given grid axes.
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+ ArrayRef<GridAxis> gridAxes = {});
// Get process linear index from a multi-index along the given grid axes .
TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder);
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+ ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes = {});
} // namespace shard
} // namespace mlir
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index b0831dc05abb1..87ae28892fcf7 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
@@ -507,103 +508,152 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
}
}
-struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename CommOp>
+struct CommOpPattern : public OpConversionPattern<CommOp> {
+ using OpConversionPattern<CommOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SymbolTableCollection symbolTableCollection;
- auto grid = adaptor.getGrid();
- mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
- if (!gridOp)
- return op->emitError() << "No grid found for AllReduceOp";
- if (ShapedType::isDynamicShape(gridOp.getShape()))
- return op->emitError()
- << "Dynamic grid shape not supported in AllReduceOp";
-
- ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
- Value input = adaptor.getInput();
- auto inputShape = cast<ShapedType>(input.getType()).getShape();
+ MemRefType getMemrefType(ShapedType tensorType) const {
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ }
+ Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
+ auto itype = input.getType();
// If the source is a memref, cast it to a tensor.
- if (isa<RankedTensorType>(input.getType())) {
- auto memrefType = MemRefType::get(
- inputShape, cast<ShapedType>(input.getType()).getElementType());
+ if (isa<RankedTensorType>(itype)) {
+ auto memrefType = getMemrefType(cast<ShapedType>(itype));
input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+ } else {
+ assert(isa<MemRefType>(itype) &&
+ "expected input to be of MemRefType or TensorType");
}
- MemRefType inType = cast<MemRefType>(input.getType());
+ return input;
+ }
- // Get the actual shape to allocate the buffer.
- SmallVector<OpFoldResult> shape(inType.getRank());
- for (auto i = 0; i < inType.getRank(); ++i) {
- auto s = inputShape[i];
- if (ShapedType::isDynamic(s))
- shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
- else
- shape[i] = iBuilder.getIndexAttr(s);
- }
+ FailureOr<GridOp> checkGrid(CommOp op,
+ SymbolTableCollection &symbolTableCollection,
+ bool allowDynamic = false) const {
+ GridOp gridOp = getGrid(op, symbolTableCollection);
+ if (!gridOp)
+ return op->emitError() << "Missing grid symbol.";
+ if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
+ return op->emitError() << "Dynamic grid shape not supported.";
+ return gridOp;
+ }
- // Allocate buffer and copy input to buffer.
- Value buffer = memref::AllocOp::create(
- iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
- linalg::CopyOp::create(iBuilder, input, buffer);
+ // Get an MPI_Comm_split for a given grid and axes.
+ // The color is the linear index of the process in the grid along the
+ // non-'grid-axes'. The key is the linear index of the process in the grid
+ // along the grid-axes.
+ Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
+ ImplicitLocOpBuilder &iBuilder) const {
+ size_t gridDims = gridOp.getShape().size();
+ auto commType = mpi::CommType::get(gridOp->getContext());
+ Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
- // Get an MPI_Comm_split for the AllReduce operation.
- // The color is the linear index of the process in the grid along the
- // non-reduced axes. The key is the linear index of the process in the grid
- // along the reduced axes.
- SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
- iBuilder.getIndexType());
- SmallVector<Value> myMultiIndex =
- ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
- .getResult();
- Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
- SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+ if (gridAxes.empty() || gridAxes.size() >= gridDims) {
+ return commWorld;
+ }
- auto redAxes = adaptor.getGridAxes();
- for (auto axis : redAxes) {
- multiKey[axis] = myMultiIndex[axis];
- myMultiIndex[axis] = zero;
+ SmallVector<GridAxis> otherAxes;
+ for (GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
+ if (!llvm::is_contained(gridAxes, i))
+ otherAxes.emplace_back(i);
}
+ SmallVector<Type> indexResultTypes(otherAxes.size(),
+ iBuilder.getIndexType());
+
Value color =
- createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
+ createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
- Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
+
+ Value key =
+ createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
- auto commType = mpi::CommType::get(op->getContext());
- Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
- auto comm =
- mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
- .getNewcomm();
-
- Value buffer1d = buffer;
- // Collapse shape to 1d if needed
- if (inType.getRank() > 1) {
- ReassociationIndices reassociation(inType.getRank());
- std::iota(reassociation.begin(), reassociation.end(), 0);
- buffer1d = memref::CollapseShapeOp::create(
- iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
- }
+ return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
+ .getNewcomm();
+ }
+};
+struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
+ using CommOpPattern::CommOpPattern;
+
+ LogicalResult
+ matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+ if (failed(gridOp))
+ return failure();
+ ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+ Value input = getAsMemref(adaptor.getInput(), iBuilder);
+ MemRefType inType = cast<MemRefType>(input.getType());
+ if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+ MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+ if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+
+ // Allocate buffer and copy input to buffer.
+ Value buffer = memref::AllocOp::create(iBuilder, outType);
+ linalg::CopyOp::create(iBuilder, input, buffer);
+ // Get the right communicator
+ Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
// Create the MPI AllReduce operation.
- mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+ mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
getMPIReductionOp(adaptor.getReductionAttr()),
comm);
- // If the destination is a memref, cast it to a tensor
+ // If the destination is a tensor, convert the buffer back to a tensor
if (isa<RankedTensorType>(op.getType()))
buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
true);
-
rewriter.replaceOp(op, buffer);
return success();
}
};
+struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
+ using CommOpPattern::CommOpPattern;
+
+ LogicalResult
+ matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+ if (failed(gridOp))
+ return failure();
+ ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+ Value input = getAsMemref(adaptor.getInput(), iBuilder);
+ MemRefType inType = cast<MemRefType>(input.getType());
+ if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+ MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+ if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+ return op.emitError(
+ "Expected static shaped memref in contiguous row-major layout.");
+
+ // Get the right communicator
+ Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
+ // Allocate output buffer
+ Value output = memref::AllocOp::create(iBuilder, outType);
+ // Create the MPI AllGather operation.
+ mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+
+ // If the destination is a tensor, convert the buffer back to a tensor
+ if (isa<RankedTensorType>(op.getType()))
+ output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
+ true);
+ rewriter.replaceOp(op, output);
+ return success();
+ }
+};
+
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;
@@ -895,8 +945,8 @@ struct ConvertShardToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
- ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
- ctxt);
+ ConvertAllGatherOp, ConvertAllReduceOp,
+ ConvertProcessLinearIndexOp>(typeConverter, ctxt);
SymbolTableCollection stc;
populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
populateAllSliceOpLoweringPatterns(patterns, stc);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 0ae2a9cc0318c..d0165595f9fb6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -128,7 +128,7 @@ static Value createDestinationPassingStyleInitOperand(
ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
ImplicitLocOpBuilder &builder) {
Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
- gridOp.getSymName(), reductionGridAxes, builder);
+ builder, gridOp.getSymName(), reductionGridAxes);
Value zero = arith::ConstantIndexOp::create(builder, 0);
Value isLeadProcess = arith::CmpIOp::create(
builder, builder.getI1Type(), arith::CmpIPredicate::eq,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index b433b8b0be7b2..835bc443d4b2a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -208,9 +208,9 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
}
TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder) {
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+ ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes) {
Operation::result_range processGroupShape =
GridShapeOp::create(builder, grid, gridAxes).getResult();
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -224,11 +224,12 @@ createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
return cast<TypedValue<IndexType>>(res);
}
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
- ArrayRef<GridAxis> gridAxes,
- ImplicitLocOpBuilder &builder) {
+TypedValue<IndexType> createProcessLinearIndex(ImplicitLocOpBuilder &builder,
+ StringRef grid,
+ ArrayRef<GridAxis> gridAxes) {
return createProcessLinearIndex(
- grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
- gridAxes, builder);
+ builder, grid,
+ ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+ gridAxes);
}
} // namespace mlir::shard
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index a0b6bfaf6fd3d..9a8ad5eea1c7b 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -102,15 +102,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
- // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
// CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
// CHECK: return [[v2]] : tensor<3x4xf32>
@@ -121,14 +120,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_memref(
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
%arg0 : memref<3x4xf32>) -> memref<3x4xf32> {
- // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
// CHECK: return [[valloc]] : memref<3x4xf32>
return %0 : memref<3x4xf32>
@@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_new_type(
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
%arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
- // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
// CHECK: return [[valloc]] : memref<3x4xf64>
return %0 : memref<3x4xf64>
}
+
+ // CHECK-LABEL: func @allgather_tensor
+ func.func @allgather_tensor(
+ // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+ // CHECK-SAME: -> tensor<3x20xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+ // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
+ // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
+ // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x20xf32> to tensor<3x20xf32>
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
+ // CHECK: return [[v2]] : tensor<3x20xf32>
+ return %0 : tensor<3x20xf32>
+ }
+
+ // CHECK-LABEL: func @allgather_memref
+ func.func @allgather_memref(
+ // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
+ // CHECK-SAME: -> memref<3x20xf32>
+ %arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
+ // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
+ // CHECK: return [[valloc]] : memref<3x20xf32>
+ return %0 : memref<3x20xf32>
+ }
}
// -----
More information about the Mlir-commits
mailing list