[Mlir-commits] [mlir] [mlir][shard, mpi] Adding Shard/MPI reduce_scatter (PR #184189)
llvmlistbot at llvm.org
Mon Mar 2 10:06:09 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
The partition pass often creates a pattern like
```
red = shard.all_reduce...grid_axes=x...
res = shard.all_slice...grid_axes=x...
```
which is effectively a `reduce-scatter` operation. A good communication library typically implements `reduce-scatter` at lower communication cost than an `allreduce` followed by a slice, because each rank only needs to receive its own block of the reduced result.
To take advantage of this, this PR:
- introduces a simplify pass, which finds such patterns and replaces them with the equivalent `reduce-scatter`
- promotes the test-pass `test-shard-optimizations` to a proper pass and adds the new pattern
- cleans up the `shard.reduce_scatter` op (renames `scatter_axis` to `scatter_dim` and accepts memrefs as well as tensors)
- adds a new `mpi.reduce_scatter_block` op
- lowers `shard.reduce_scatter` to MPI
- lowers `mpi.reduce_scatter_block` to LLVM
https://github.com/llvm/lighthouse/pull/58 works nicely with the changes.
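
The equivalence between the `all_reduce`/`all_slice` pair and a single reduce-scatter can be sketched in plain C++ by simulating the ranks as a vector of buffers. This is only an illustration of the semantics, not code from the PR: the helper names `allReduceSum`, `sliceForRank`, and `reduceScatterSum` are hypothetical, the reduction is assumed to be a sum, and the slice is assumed to be along the first (flattened) dimension with equal-sized blocks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Buffer = std::vector<double>;

// Element-wise sum over the buffers of all ranks. After an all-reduce,
// every rank would hold this full result.
Buffer allReduceSum(const std::vector<Buffer> &perRank) {
  Buffer out(perRank.front().size(), 0.0);
  for (const Buffer &buf : perRank)
    for (std::size_t i = 0; i < out.size(); ++i)
      out[i] += buf[i];
  return out;
}

// The slice step: rank `rank` keeps only its equal-sized block of `full`.
Buffer sliceForRank(const Buffer &full, std::size_t rank, std::size_t nRanks) {
  assert(full.size() % nRanks == 0 && "blocks must be equal-sized");
  std::size_t block = full.size() / nRanks;
  return Buffer(full.begin() + rank * block,
                full.begin() + (rank + 1) * block);
}

// The fused variant: reduce only the block that `rank` will own, so the
// full reduced buffer is never materialized on any single rank.
Buffer reduceScatterSum(const std::vector<Buffer> &perRank, std::size_t rank) {
  std::size_t nRanks = perRank.size();
  assert(perRank.front().size() % nRanks == 0 && "blocks must be equal-sized");
  std::size_t block = perRank.front().size() / nRanks;
  Buffer out(block, 0.0);
  for (const Buffer &buf : perRank)
    for (std::size_t i = 0; i < block; ++i)
      out[i] += buf[rank * block + i];
  return out;
}
```

For every rank, `sliceForRank(allReduceSum(bufs), r, n)` and `reduceScatterSum(bufs, r)` produce the same block; the fused form simply avoids delivering the other ranks' blocks.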
---
Patch is 57.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184189.diff
23 Files Affected:
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+34)
- (modified) mlir/include/mlir/Dialect/Shard/IR/ShardOps.td (+10-10)
- (modified) mlir/include/mlir/Dialect/Shard/Transforms/Partition.h (+1-1)
- (modified) mlir/include/mlir/Dialect/Shard/Transforms/Passes.td (+17)
- (renamed) mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h (+8-8)
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+74-1)
- (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+152-2)
- (modified) mlir/lib/Dialect/Shard/IR/ShardOps.cpp (+3-3)
- (modified) mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt (+2-1)
- (renamed) mlir/lib/Dialect/Shard/Transforms/Simplify.cpp (+81-27)
- (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+24)
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+34)
- (modified) mlir/test/Dialect/MPI/mpiops.mlir (+6)
- (modified) mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir (+6-6)
- (modified) mlir/test/Dialect/Shard/canonicalization.mlir (+4-4)
- (modified) mlir/test/Dialect/Shard/folding.mlir (+1-1)
- (modified) mlir/test/Dialect/Shard/invalid.mlir (+10-10)
- (modified) mlir/test/Dialect/Shard/ops.mlir (+10-10)
- (renamed) mlir/test/Dialect/Shard/simplify.mlir (+84-1)
- (modified) mlir/test/lib/Dialect/Shard/CMakeLists.txt (-1)
- (modified) mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp (+1-1)
- (removed) mlir/test/lib/Dialect/Shard/TestSimplifications.cpp (-47)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7e68b152fdf75..0ed2a98545870 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -333,6 +333,40 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
"(`->` type($retval)^)?";
}
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOp
+//===----------------------------------------------------------------------===//
+
+def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
+ let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, "
+ "recvcount, dtype, op, comm)`";
+ let description = [{
+ MPI_Reduce_scatter_block first performs an element-wise reduction on the
+ sendbuf across all processes in the communicator, then scatters the result
+ by distributing equal-sized blocks to each process into recvbuf.
+
+ The `op` attribute specifies the reduction operation to be performed.
+ Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
+ supported.
+
+ This operation can optionally return an `!mpi.retval` value that can be used
+ to check for errors.
+ }];
+
+ let arguments = (
+ ins AnyNon0RankedMemRef : $sendbuf,
+ AnyNon0RankedMemRef : $recvbuf,
+ MPI_ReductionOpEnum : $op,
+ MPI_Comm : $comm
+ );
+
+ let results = (outs Optional<MPI_Retval>:$retval);
+
+ let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
+ "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+ "(`->` type($retval)^)?";
+}
+
//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 6ef7c72d305ee..60f6d6fc1ffe4 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
`grid_axes` using the specified reduction method. The reduction is performed
element-wise across the tensor pieces from all devices in the group.
After reduction, the reduction result is scattered (split and distributed)
- across the device group along `scatter_axis`.
+ across the device group along `scatter_dim`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
- reduction = <max> scatter_axis = 0
+ reduction = <max> scatter_dim = 0
: tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
@@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
```
}];
let arguments = !con(commonArgs, (ins
- AnyNon0RankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
- IndexAttr:$scatter_axis
+ IndexAttr:$scatter_dim
));
let results = (outs
- AnyRankedTensor:$result
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
- `scatter_axis` `=` $scatter_axis
+ `scatter_dim` `=` $scatter_dim
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
@@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
let summary = "Scatter over a device grid.";
let description = [{
For each device group defined by `grid_axes`, the input tensor on the `root`
- device is split along axis `scatter_axis` and distributed across the group.
+ device is split along axis `scatter_dim` and distributed across the group.
The content of the input on all other (non-root) devices is ignored.
The `root` device is defined by its in-group multi-index.
@@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
```
shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
- scatter_axis = 0
+ scatter_dim = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
```
@@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
- IndexAttr:$scatter_axis,
+ IndexAttr:$scatter_dim,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
@@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
- `scatter_axis` `=` $scatter_axis
+ `scatter_dim` `=` $scatter_dim
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
index 37903765903db..ba12002b9f1eb 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
index bbc6a1977b13e..575c176217e61 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
@@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
];
}
+def ShardSimplify : Pass<"shard-simplify"> {
+ let summary = "Shard simplify patterns.";
+ let description = [{
+ Applies simplification patterns on the Shard dialect operations.
+ This includes:
+ - All-reduce endomorphism simplification, e.g. transforming
+ `all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`.
+ - Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops
+ share the same grid and grid_axes.
+ - Folding static grid shapes into constants.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "shard::ShardDialect"
+ ];
+}
+
def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
similarity index 89%
rename from mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
index 45ae758ec14c2..f3f4feffd8a71 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
-#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -37,8 +37,8 @@ namespace shard {
// Will not work with some op `f(x, y, z)` where only `x` and `y` form
// the algebraic structure.
template <typename AlgebraicOp>
-void populateAllReduceEndomorphismSimplificationPatterns(
- RewritePatternSet &patterns, ReductionKind reduction) {
+void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns,
+ ReductionKind reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
return &allReduceOp.getInputMutable();
@@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
// It is invalid to change ops that declare symbols during the application of
// these patterns, because symbolTableCollection is used to cache them.
-void populateSimplificationPatterns(
- RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void populateSimplifyPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTableCollection);
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection);
} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 0dbc0a126a5c6..5087b96adf8ec 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -911,6 +912,78 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOpLowering
+//===----------------------------------------------------------------------===//
+
+struct ReduceScatterBlockOpLowering
+ : public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *context = rewriter.getContext();
+ Type i32 = rewriter.getI32Type();
+ Type i64 = rewriter.getI64Type();
+ Type elemType = op.getSendbuf().getType().getElementType();
+ int64_t sRank = op.getSendbuf().getType().getRank();
+ int64_t rRank = op.getRecvbuf().getType().getRank();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ auto [sendPtr, sendSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
+ auto [recvPtr, recvSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
+
+ // If input and output are the same, request in-place operation.
+ if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+ sendPtr = LLVM::ConstantOp::create(
+ rewriter, loc, i64,
+ reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+ sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
+ }
+
+ Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+ Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+ Value nRanks =
+ createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
+ Value expected = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+ Value sizeIsValid = LLVM::ICmpOp::create(
+ rewriter, loc, LLVM::ICmpPredicate::eq, expected, recvSize);
+ cf::AssertOp::create(rewriter, loc, sizeIsValid,
+ "Send buffer's size must be the receive buffer's size "
+ "times the number of ranks");
+
+ // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
+ // int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+ comm.getType()});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
+
+ // replace op with function call
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
+ ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
+
+ if (op.getRetval())
+ rewriter.replaceOp(op, funcCall.getResult());
+ else
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@@ -943,7 +1016,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
SendOpLowering, RecvOpLowering, AllGatherOpLowering,
- AllReduceOpLowering>(converter);
+ AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 1db14e60c5a7f..68448fcbd4427 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -26,7 +26,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplify.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -618,6 +618,156 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
}
};
+struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
+ using CommOpPattern::CommOpPattern;
+
+ // shard.reduce_scatter reduces and then scatters along a specified
+ // scatter-dim. mpi.reduce_scatter_block always scatters along the first
+ // dimension. Hence, if scatter-dim != 0, we need to rearrange the input
+ // data by expanding the scatter-dim into {nRanks, output_scatter_dim} and
+ // transposing nRanks to the first dimension.
+
+ LogicalResult
+ matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto gridAxes = adaptor.getGridAxes();
+ int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
+
+ SymbolTableCollection symbolTableCollection;
+ FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+ if (failed(gridOp))
+ return failure();
+
+ ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+ Value rawInput = adaptor.getInput();
+ auto inShapedType = cast<ShapedType>(rawInput.getType());
+ auto elemType = inShapedType.getElementType();
+ MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+ auto inputShape = inShapedType.getShape();
+ auto outputShape = outType.getShape();
+ int64_t inputDimOnAxis = inputShape[scatterDim];
+ int64_t outputDimOnAxis = outputShape[scatterDim];
+
+ for (size_t i = 0; i < outputShape.size(); ++i)
+ if (outputShape[i] != inputShape[i] &&
+ i != static_cast<size_t>(scatterDim))
+ return op.emitError(
+ "Result and input shapes must match along non-scatter axes.");
+ if (outputDimOnAxis == 0)
+ return op.emitError(
+ "Output size along the scatter axis must be non-zero.");
+ if (inputDimOnAxis % outputDimOnAxis != 0)
+ return op.emitError(
+ "Input size along the scatter axis must be an exact "
+ "multiple of the output size along the scatter axis.");
+
+ if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+ return op.emitError("Result must be a statically shaped memref in "
+ "contiguous row-major layout.");
+
+ int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
+
+ // Verify that nRanks matches the number of devices along the grid axes.
+ int64_t gridGroupSize =
+ collectiveProcessGroupSize(gridAxes, gridOp->getShape());
+ if (nRanks != gridGroupSize)
+ return op.emitError()
+ << "Expected the scatter factor (" << nRanks
+ << ") to match the number of devices along grid_axes ("
+ << gridGroupSize << ").";
+
+ // Get the right communicator.
+ Value comm = getComm(*gridOp, gridAxes, ib);
+
+ Value mpiInput;
+ if (scatterDim == 0) {
+ // scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
+ // Input must be contiguous for MPI.
+ Value input = getAsMemref(rawInput, ib);
+ MemRefType inType = cast<MemRefType>(input.getType());
+ if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+ return op.emitError("Input must be a statically shaped memref in "
+ "contiguous row-major layout.");
+ mpiInput = input;
+ } else {
+ // For scatter_dim != 0 we rearrange the input so the scatter factor
+ // becomes the first dimension.
+ //
+ // 1. Get a tensor representation of the input (avoid memref->tensor
+ // round-trip if the input is already a tensor).
+ Value tensorInput = rawInput;
+ if (!isa<RankedTensorType>(rawInput.getType())) {
+ auto inTensorType = RankedTensorType::get(inputShape, elemType);
+ tensorInput =
+ bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
+ }
+
+ // 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
+ // {d0, ..., nRanks, o_sd, ..., dN}.
+ SmallVector<int64_t> expandedShape;
+ SmallVector<ReassociationIndices> expandReassociation;
+ int64_t expandedIdx = 0;
+ for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
+ if (i == scatterDim) {
+ expandedShape.push_back(nRanks);
+ expandedShape.push_back(outputDimOnAxis);
+ expandReassociation.push_back({expandedIdx, expandedIdx + 1});
+ expandedIdx += 2;
+ } else {
+ expandedShape.push_back(inputShape[i]);
+ expandReassociation.push_back({expandedIdx});
+ expandedIdx += 1;
+ }
+ }
+ auto expandedType = RankedTensorType::get(expandedShape, elemType);
+ tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
+ expandReassociation);
+
+ // 3. Transpose to move nRanks (at position scatterDim) to position 0:
+ // {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
+ SmallVector<int64_t> permutation, transposedShape;
+ permutation.emplace_back(scatterDim);
+ for (int64_t i = 0; i < scatterDim; ++i)
+ permutation.emplace_back(i);
+ for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i)
+ permutation.emplace_back(i);
+ for (auto p : permutation)
+ transposedShape.emplace_back(expandedShape[p]);
+
+ Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
+ tensorInput =
+ linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
+ ->getResult(0);
+
+ // 4. ...
[truncated]
``````````
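
The rearrangement described in `ConvertReduceScatterOp` for `scatter_dim != 0` (expand the scatter dimension into `{nRanks, o_sd}`, then transpose `nRanks` to the front) can be illustrated on a flat row-major buffer. The sketch below is a standalone approximation under stated assumptions, not the lowering itself: `moveScatterBlocksToFront` is a hypothetical helper, shapes are static, and the element type is fixed to `int`. It collapses the dimensions before and after the scatter dimension into `outer` and `inner`, which yields the same memory layout as the expand-and-transpose steps.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Rearranges a row-major buffer of shape `shape` so that the nRanks blocks
// along `scatterDim` become contiguous. The result has the logical shape
// {nRanks, d0, ..., o_sd, ..., dN}, which is the layout a block-wise
// scatter along the first dimension expects.
std::vector<int> moveScatterBlocksToFront(const std::vector<int> &in,
                                          const std::vector<std::int64_t> &shape,
                                          std::size_t scatterDim,
                                          std::int64_t nRanks) {
  assert(scatterDim < shape.size());
  assert(nRanks > 0 && shape[scatterDim] % nRanks == 0);

  // Collapse the dims before/after scatterDim into `outer` and `inner`.
  std::int64_t outer = 1, inner = 1;
  for (std::size_t i = 0; i < scatterDim; ++i)
    outer *= shape[i];
  for (std::size_t i = scatterDim + 1; i < shape.size(); ++i)
    inner *= shape[i];

  // Elements per (outer index, rank) pair: o_sd * inner.
  std::int64_t chunk = shape[scatterDim] / nRanks * inner;
  assert(in.size() == static_cast<std::size_t>(outer * nRanks * chunk));

  // View input as [outer, nRanks, chunk]; write output as [nRanks, outer, chunk].
  std::vector<int> out(in.size());
  for (std::int64_t r = 0; r < nRanks; ++r)
    for (std::int64_t o = 0; o < outer; ++o)
      for (std::int64_t c = 0; c < chunk; ++c)
        out[(r * outer + o) * chunk + c] = in[(o * nRanks + r) * chunk + c];
  return out;
}
```

For a `2x4` input scattered along dimension 1 over two ranks, the first half of the output holds columns 0 and 1 of every row and the second half holds columns 2 and 3. With `scatterDim == 0` the function is the identity, matching the direct mapping the lowering uses in that case.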
</details>
https://github.com/llvm/llvm-project/pull/184189