[Mlir-commits] [mlir] [mlir][mesh] removing partial/reduction axes from mesh.sharding (PR #149805)
Frank Schlimbach
llvmlistbot at llvm.org
Mon Jul 21 05:41:05 PDT 2025
https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/149805
As discussed on Discourse (87053), this PR removes partial axes from sharding annotations.
The dialect provides types and operations across two distinct domains — sharding/partitioning and data exchange — yet one operation (`mesh.sharding`) conflated the two by allowing implicit `mesh.allreduce` behavior when partial axes are specified. Beyond being conceptually unclean, this coupling complicates the analysis needed to generate sharding/partitioning plans and inflates data structures. Sharding should focus solely on defining the data layout across devices, while reductions and other communications should be treated as part of sharded operation semantics — not sharding itself. The `ShardingInterface` is the right abstraction for capturing operation-specific requirements. Its `spmdize` method can insert the appropriate communication when tensors are sharded. Moving the responsibility for adding reductions from `mesh.sharding` into `ShardingInterface.spmdize` cleanly separates concerns, simplifies the sharding syntax, and reduces the burden on authors of sharding annotations.
Some examples will currently lead to more resharding than before. The partial axes annotation was used to avoid unnecessary communication. This can (and should) also be done by a dedicated optimization pass. Parts of the necessary mechanics for this already exist.
@tkarna
>From e0653f6643e5117f5da493212a317fe82ccedae0 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 17 Jul 2025 14:36:00 +0200
Subject: [PATCH] removing partial/reduction axes from mesh.sharding
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 15 +-
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 29 +--
.../Transforms/MeshShardingInterfaceImpl.cpp | 30 +--
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 69 ++----
.../Mesh/Interfaces/ShardingInterface.cpp | 82 +------
.../Dialect/Mesh/Transforms/Spmdization.cpp | 102 ++-------
.../test/Dialect/Linalg/mesh-spmdization.mlir | 48 ----
.../Mesh/forward-sharding-propagation.mlir | 2 +-
mlir/test/Dialect/Mesh/invalid.mlir | 24 --
mlir/test/Dialect/Mesh/ops.mlir | 55 -----
.../Dialect/Mesh/resharding-spmdization.mlir | 14 --
.../Dialect/Mesh/sharding-propagation.mlir | 211 +++++++++---------
mlir/test/Dialect/Mesh/spmdization.mlir | 46 +++-
13 files changed, 207 insertions(+), 520 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 7213fde45c695..7cfe59dd957ca 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -43,8 +43,6 @@ class MeshSharding {
private:
::mlir::FlatSymbolRefAttr mesh;
SmallVector<MeshAxesAttr> split_axes;
- SmallVector<MeshAxis> partial_axes;
- ReductionKind partial_type = ReductionKind::Sum;
SmallVector<int64_t> static_halo_sizes;
SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
@@ -55,8 +53,6 @@ class MeshSharding {
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
- ArrayRef<MeshAxis> partial_axes_ = {},
- ReductionKind partial_type_ = ReductionKind::Sum,
ArrayRef<int64_t> static_halo_sizes_ = {},
ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
ArrayRef<Value> dynamic_halo_sizes_ = {},
@@ -64,8 +60,6 @@ class MeshSharding {
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
- ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
- ReductionKind getPartialType() const { return partial_type; }
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
return static_sharded_dims_offsets;
@@ -79,7 +73,7 @@ class MeshSharding {
bool operator!=(Value rhs) const;
bool operator==(const MeshSharding &rhs) const;
bool operator!=(const MeshSharding &rhs) const;
- bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
+ bool equalSplitAxes(const MeshSharding &rhs) const;
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
bool equalHaloSizes(const MeshSharding &rhs) const;
bool equalShardSizes(const MeshSharding &rhs) const;
@@ -110,10 +104,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshSharding sharding) {
- return sharding.getPartialAxes().empty() &&
- llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
- return axes.asArrayRef().empty();
- });
+ return llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
+ return axes.asArrayRef().empty();
+ });
}
inline mesh::MeshOp
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index ac05ee243d7be..1662885c161e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -204,7 +204,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
let description = [{
The MeshSharding specifies how a tensor is sharded and distributed across the
process mesh. It is typically used in a `mesh.shard` operation.
- The operation has the follwing attributes and operands:
+ The operation has the following attributes and operands:
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -215,15 +215,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
its value is [x, y], it indicates that the tensor's i-th dimension is splitted
along the x and y axes of the device mesh.
- 3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
- one along the specified mesh axes. An all-reduce should be applied to obtain
- the complete tensor, with reduction type being specified by `partial_type`.
-
- 4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
- op. It has 4 possible values:
- `generic`: is not an allowed value inside a shard attribute.
-
- 5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+ 3. [Optional] Sizes of halos to be added for each sharded tensor dimension.
`halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
gets an additional halo of size 1 at the start of the first dimension and a halo
@@ -231,7 +223,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
- 6. [Optional] Offsets for each shard and sharded tensor dimension.
+ 4. [Optional] Offsets for each shard and sharded tensor dimension.
`sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
sharded tensor dimension the offsets (starting index) of all shards in that
dimension and an additional value for the end of the last shard are provided.
@@ -260,14 +252,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
// The tensor is sharded on the first dimension along axis 0 of @mesh0
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
- // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
- // it is also a partial_sum along mesh axis 1.
- %sharding2 = mesh.sharding @mesh0 split_axes = [[0] split_axes = []] partial = sum[1]
-
- // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
- // it is also a partial_max along mesh axis 1.
- %sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]
-
// Could be used for a mesh.shard op
%sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
@@ -287,8 +271,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
let arguments = (ins
FlatSymbolRefAttr:$mesh,
Mesh_MeshAxesArrayAttr:$split_axes,
- OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
- OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
Variadic<I64>:$dynamic_sharded_dims_offsets,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
@@ -300,7 +282,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
let assemblyFormat = [{
$mesh
`split_axes` `=` $split_axes
- (`partial` `=` $partial_type $partial_axes^)?
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
(`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
@@ -308,12 +289,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
let builders = [
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
- "ArrayRef<MeshAxis>":$partial_axes,
- "mesh::ReductionKind":$partial_type,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
- OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<MeshAxesAttr>":$split_axes)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index ee1957aaa6a53..8208a3123050e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -187,27 +187,6 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
return newOperands;
}
-static void createAllReduceForResultWithoutPartialSharding(
- Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
- MeshSharding resultSharding, ReductionKind reductionKind,
- IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
- SmallVector<MeshAxis> allReduceMeshAxes;
- llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
- [&resultSharding](MeshAxis axis) {
- return !llvm::is_contained(resultSharding.getPartialAxes(),
- axis);
- });
- if (allReduceMeshAxes.empty()) {
- return;
- }
-
- Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
- Value reducedValue = builder.create<mesh::AllReduceOp>(
- spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
- reductionKind);
- spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
-}
-
static void createAllReduceForResultsWithoutPartialShardings(
LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
@@ -215,9 +194,12 @@ static void createAllReduceForResultsWithoutPartialShardings(
ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
for (auto [unshardedLinalgOpResult, resultSharding] :
llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
- createAllReduceForResultWithoutPartialSharding(
- unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
- reductionKind, spmdizationMap, builder);
+ Value spmdizedLinalgOpResult =
+ spmdizationMap.lookup(unshardedLinalgOpResult);
+ Value reducedValue = builder.create<mesh::AllReduceOp>(
+ spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes,
+ reductionKind);
+ spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
}
}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index cf506d1e7812b..61ca81dc9c4c1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -479,37 +479,23 @@ void MeshShapeOp::getAsmResultNames(
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh,
ArrayRef<MeshAxesAttr> split_axes,
- ArrayRef<MeshAxis> partial_axes,
- mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
- ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
- ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
-void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr mesh,
- ArrayRef<MeshAxesAttr> split_axes) {
- return build(
- b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
- ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
- {}, {}, {}, {});
-}
-
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
- return build(
- b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
- MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
- ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+ return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+ MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
+ {});
}
void ShardingOp::build(
@@ -522,8 +508,7 @@ void ShardingOp::build(
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
return build(
- b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
- ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+ b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
@@ -533,11 +518,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
- from.getPartialAxes().empty()
- ? DenseI16ArrayAttr()
- : b.getDenseI16ArrayAttr(from.getPartialAxes()),
- ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
- from.getPartialType()),
from.getStaticShardedDimsOffsets().empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -566,9 +546,6 @@ LogicalResult ShardingOp::verify() {
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
- if (getPartialAxes().has_value() &&
- failed(checkMeshAxis(getPartialAxes().value())))
- return failure();
if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
return emitOpError("halo sizes and shard offsets are mutually exclusive");
@@ -710,17 +687,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
// MeshSharding
//===----------------------------------------------------------------------===//
-bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
+bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
if (getMesh() != rhs.getMesh()) {
return false;
}
- if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
- (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
- !llvm::equal(getPartialAxes(), rhs.getPartialAxes())) {
- return false;
- }
-
auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
getSplitAxes().begin() + minSize),
@@ -768,13 +739,13 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
}
bool MeshSharding::operator==(Value rhs) const {
- return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
+ return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
}
bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
bool MeshSharding::operator==(const MeshSharding &rhs) const {
- return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
+ return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
}
bool MeshSharding::operator!=(const MeshSharding &rhs) const {
@@ -787,30 +758,26 @@ MeshSharding::MeshSharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
- auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
- // If splitAxes and partialAxes are empty, use "empty" constructor.
- if (splitAxes.empty() && partialAxes.empty()) {
+ // If splitAxes are empty, use "empty" constructor.
+ if (splitAxes.empty()) {
*this = MeshSharding(shardingOp.getMeshAttr());
return;
}
- *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
- shardingOp.getPartialType().value_or(ReductionKind::Sum),
- shardingOp.getStaticHaloSizes(),
- shardingOp.getStaticShardedDimsOffsets(),
- SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
- SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
+ *this =
+ get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
+ shardingOp.getStaticShardedDimsOffsets(),
+ SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+ SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
- ArrayRef<MeshAxis> partial_axes_,
- ReductionKind partial_type_,
ArrayRef<int64_t> static_halo_sizes_,
ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res(mesh_);
- if (split_axes_.empty() && partial_axes_.empty()) {
+ if (split_axes_.empty()) {
return res;
}
@@ -825,8 +792,6 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
llvm::copy(src, dst.begin());
};
- clone(partial_axes_, res.partial_axes);
- res.partial_type = partial_type_;
clone(static_halo_sizes_, res.static_halo_sizes);
clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dca2b1a52166a..6b3d49e08b549 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -271,7 +271,6 @@ mesh::detail::defaultGetShardingOption(Operation *op,
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
shardingOption.shardingArray.resize(loopTypes.size());
- llvm::SmallVector<MeshAxis> partialMeshAxes;
llvm::SmallSet<unsigned, 4> visitedLoopIndices;
bool anyShardingInResultsOrOperands = false;
@@ -299,22 +298,6 @@ mesh::detail::defaultGetShardingOption(Operation *op,
return failure();
}
}
-
- // Handle the partial axes: at this stage, the exact loop index/indices
- // cannot be decided because there could be multiple reduction loops.
- ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
- if (!partialAxes.empty()) {
- if (!partialMeshAxes.empty())
- return op->emitOpError() << "at most one result with partial axes is "
- "supported at present";
- partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
- // Add all the reduction loop indices to `visitedLoopIndices` if
- // `partialAxes` is not empty
- for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
- if (isReductionLoop(loopTypes[loopIdx]))
- visitedLoopIndices.insert(loopIdx);
- }
- }
}
// 2. Fill sharding option based on operands
@@ -327,8 +310,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
AffineMap map = maps[shardingIt.index()];
unsigned numDims = map.getNumDims();
- // Handle the split axes. Partial axes don't need to be handled because they
- // only affect the defining op of the operand.
+ // Handle the split axes.
//
// TODO: Change to process the operands with single loop index first and
// then the operands with multiple loop indices.
@@ -372,28 +354,6 @@ mesh::detail::defaultGetShardingOption(Operation *op,
}
// 3. Finalize sharding option
- if (!partialMeshAxes.empty()) {
- bool anyNonEmptyReductionLoop = llvm::any_of(
- llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
- SmallVector<MeshAxis> &subArray = it.value();
- int64_t idx = it.index();
- return isReductionLoop(loopTypes[idx]) && !subArray.empty();
- });
- if (!anyNonEmptyReductionLoop) {
- bool filled = false;
- for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
- if (isReductionLoop(loopTypes[idx])) {
- std::ignore = fillShardingOption(op, shardingOption, nullptr,
- partialMeshAxes, idx);
- filled = true;
- break;
- }
- }
- if (!filled)
- return op->emitOpError() << "no matched reduction loop found for the "
- "result's partial type";
- }
- }
removeTrailingEmptySubArray(shardingOption.shardingArray);
if (!anyShardingInResultsOrOperands)
shardingOption.empty = true;
@@ -402,11 +362,10 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// Get the sharding attributed for the given result and sharding option.
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
- AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
- ArrayRef<ReductionKind> reductionLoopKinds) {
+ AffineMap map,
+ ArrayRef<utils::IteratorType> loopTypes) {
auto resultType = cast<RankedTensorType>(result.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
- SmallVector<MeshAxis> partialAxes;
// process the split axes
for (auto it : llvm::enumerate(map.getResults())) {
@@ -419,28 +378,9 @@ MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
}
- // process the partial axes
- // partialType will be ignored if partialAxes is empty
- ReductionKind partialType = ReductionKind::Sum;
- size_t reductionLoopKindsIdx = 0;
- for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
- utils::IteratorType iType = std::get<0>(it);
- if (isReductionLoop(iType)) {
- ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
- ++reductionLoopKindsIdx;
- if (!partialAxes.empty())
- assert(partialType == curPartialType &&
- "Only one reduction type is supported");
- partialType = curPartialType;
- const SmallVector<MeshAxis> &axis = std::get<1>(it);
- partialAxes.append(axis);
- }
- }
-
removeTrailingEmptySubArray(splitAxes);
return MeshSharding::get(shardingOption.mesh,
- fromArrayOfVector(result.getContext(), splitAxes),
- partialAxes, partialType);
+ fromArrayOfVector(result.getContext(), splitAxes));
}
static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
@@ -495,8 +435,6 @@ mesh::detail::defaultGetShardingAnnotations(
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
shardingOp.getLoopIteratorTypes();
- SmallVector<ReductionKind> reductionKinds =
- shardingOp.getReductionLoopIteratorKinds();
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
@@ -511,7 +449,7 @@ mesh::detail::defaultGetShardingAnnotations(
for (OpResult result : op->getResults()) {
res.push_back(getSharding(result, shardingOption,
maps[numOperands + result.getResultNumber()],
- loopTypes, reductionKinds));
+ loopTypes));
}
return res;
@@ -526,10 +464,8 @@ mesh::detail::defaultGetShardingAnnotations(
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
const ShardingOption &shardingOption,
AffineMap map,
- ArrayRef<utils::IteratorType> loopTypes,
- ArrayRef<ReductionKind> reductionLoopKinds) {
- MeshSharding sharding =
- getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
+ ArrayRef<utils::IteratorType> loopTypes) {
+ MeshSharding sharding = getSharding(result, shardingOption, map, loopTypes);
maybeInsertTargetShardingAnnotation(sharding, result, b);
return success();
@@ -559,8 +495,6 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
shardingOp.getLoopIteratorTypes();
- SmallVector<ReductionKind> reductionKinds =
- shardingOp.getReductionLoopIteratorKinds();
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
@@ -568,7 +502,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
for (OpResult result : op->getResults()) {
if (failed(addShardOp(b, result, shardingOption,
maps[numOperands + result.getResultNumber()],
- loopTypes, reductionKinds)))
+ loopTypes)))
return failure();
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index d7b7234f69347..6632930fef3b8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -46,63 +46,6 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
});
}
-// Return the reduced value and its corresponding sharding.
-// Example:
-// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
-// targetSharding = <@mesh_1d, [[]]>
-// Then will apply all-reduce on the source value
-// and return it with the sharding <@mesh_1d, [[0]]>.
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-handlePartialAxesDuringResharding(OpBuilder &builder,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
- TypedValue<ShapedType> sourceShard) {
- if (sourceSharding.getPartialAxes().empty() &&
- targetSharding.getPartialAxes().empty()) {
- return {sourceShard, sourceSharding};
- }
- assert(targetSharding.getPartialAxes().empty() ||
- (!sourceSharding.getPartialAxes().empty() &&
- sourceSharding.getPartialType() == targetSharding.getPartialType()));
- using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
- using AxisSet = llvm::SmallDenseSet<Axis>;
- AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
- sourceSharding.getPartialAxes().end());
- AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
- targetSharding.getPartialAxes().end());
- assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
- targetShardingPartialAxesSet));
- llvm::SmallVector<MeshAxis> allReduceMeshAxes;
- llvm::copy_if(sourceShardingPartialAxesSet,
- std::back_inserter(allReduceMeshAxes),
- [&targetShardingPartialAxesSet](Axis a) {
- return !targetShardingPartialAxesSet.contains(a);
- });
- if (allReduceMeshAxes.empty()) {
- return {sourceShard, sourceSharding};
- }
-
- builder.setInsertionPointAfterValue(sourceShard);
- TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
- builder
- .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
- sourceSharding.getMeshAttr().getLeafReference(),
- allReduceMeshAxes, sourceShard,
- sourceSharding.getPartialType())
- .getResult());
-
- llvm::SmallVector<MeshAxis> remainingPartialAxes;
- llvm::copy_if(sourceShardingPartialAxesSet,
- std::back_inserter(allReduceMeshAxes),
- [&targetShardingPartialAxesSet](Axis a) {
- return targetShardingPartialAxesSet.contains(a);
- });
- MeshSharding resultSharding = MeshSharding::get(
- sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
- remainingPartialAxes, sourceSharding.getPartialType());
- return {resultValue, resultSharding};
-}
-
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
MeshSharding sourceSharding,
int64_t splitTensorAxis,
@@ -118,9 +61,8 @@ static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
targetSplitAxes.push_back(splitMeshAxis);
targetShardingSplitAxes[splitTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(
- sourceSharding.getMeshAttr(), targetShardingSplitAxes,
- sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+ return MeshSharding::get(sourceSharding.getMeshAttr(),
+ targetShardingSplitAxes);
}
// Split a replicated tensor along a mesh axis.
@@ -239,9 +181,8 @@ static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(
- sourceSharding.getMeshAttr(), targetShardingSplitAxes,
- sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+ return MeshSharding::get(sourceSharding.getMeshAttr(),
+ targetShardingSplitAxes);
}
static ShapedType allGatherResultShapeInUnsplitLastAxis(
@@ -366,9 +307,8 @@ static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
targetShardingSplitAxes[targetTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(
- sourceSharding.getMeshAttr(), targetShardingSplitAxes,
- sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+ return MeshSharding::get(sourceSharding.getMeshAttr(),
+ targetShardingSplitAxes);
}
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
@@ -439,9 +379,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> sourceShard) {
// Currently handles only cases where halo sizes differ but everything else
// stays the same (from source to destination sharding).
- if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
- !sourceSharding.getPartialAxes().empty() ||
- !targetSharding.getPartialAxes().empty() ||
+ if (!sourceSharding.equalSplitAxes(targetSharding) ||
!sourceSharding.getStaticShardedDimsOffsets().empty() ||
!targetSharding.getStaticShardedDimsOffsets().empty() ||
sourceSharding.equalHaloSizes(targetSharding)) {
@@ -519,31 +457,27 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
assert(sourceShard.getType().getRank() == targetShardType.getRank());
assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
- auto [reducedSourceShard, reducedSourceSharding] =
- handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
- sourceShard);
-
- if (reducedSourceSharding == targetSharding) {
- return reducedSourceShard;
+ if (sourceSharding == targetSharding) {
+ return sourceShard;
}
TypedValue<ShapedType> targetShard;
MeshSharding actualTargetSharding;
- if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
+ if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
targetSharding.getStaticShardedDimsOffsets().empty() &&
- reducedSourceSharding.getStaticHaloSizes().empty() &&
+ sourceSharding.getStaticHaloSizes().empty() &&
targetSharding.getStaticHaloSizes().empty()) {
if (auto tryRes = tryMoveLastSplitAxisInResharding(
- builder, mesh, reducedSourceSharding, targetSharding,
- sourceUnshardedValue.getType(), reducedSourceShard)) {
+ builder, mesh, sourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
- } else if (auto tryRes = trySplitLastAxisInResharding(
- builder, mesh, reducedSourceSharding, targetSharding,
- reducedSourceShard)) {
+ } else if (auto tryRes =
+ trySplitLastAxisInResharding(builder, mesh, sourceSharding,
+ targetSharding, sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
} else if (auto tryRes = tryUnsplitLastAxisInResharding(
- builder, mesh, reducedSourceSharding, targetSharding,
- sourceUnshardedValue.getType(), reducedSourceShard)) {
+ builder, mesh, sourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
}
}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 9805ee4ea5525..ce12b296df1fa 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -128,54 +128,6 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
// -----
-mesh.mesh @mesh_1d(shape = 3)
-
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result
-func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result(
- // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
- %in1: tensor<4x6xi8>,
-// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
- %in2: tensor<6x8xi8>,
-// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
- %dps_out: tensor<4x8xi8>
-// CHECK-SAME: -> tensor<4x8xi8> {
-) -> tensor<4x8xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
- %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
- %sdps_out_shared2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
- // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
- // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
- // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
- // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
- // CHECK: } else {
- // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
- // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
- // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
- // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
- // CHECK: }
- // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
- // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
- outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %sharding4 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
- %res_shared1 = mesh.shard %res to %sharding4 : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to %sharding4 annotate_for_users : tensor<4x8xi8>
- // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
- return %res_shared2 : tensor<4x8xi8>
-}
-
-// -----
-
mesh.mesh @mesh_1d(shape = 4)
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
index 98e9931b8de94..6ab711b1b653c 100644
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -33,7 +33,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:com
%6 = tensor.empty() : tensor<i32>
%7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
// CHECK: [[vreduced:%.*]] = linalg.reduce ins
- // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
+ // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] : !mesh.sharding
// CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
%reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
(%in: i32, %init: i32) {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 29b900a8da4a6..2656332942382 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -36,18 +36,6 @@ func.func @mesh_axis_duplicated_same_subarray(
mesh.mesh @mesh0(shape = 2x4)
-func.func @mesh_axis_duplicated_bewteen_split_and_partial(
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error at +1 {{mesh axis duplicated}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[0] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
-// -----
-
-mesh.mesh @mesh0(shape = 2x4)
-
func.func @mesh_axis_negtive_in_split_part(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -58,18 +46,6 @@ func.func @mesh_axis_negtive_in_split_part(
// -----
-mesh.mesh @mesh0(shape = 2x4)
-
-func.func @mesh_axis_negtive_in_partial(
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error at +1 {{mesh axis is expected to be non-negative}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[-1] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
-// -----
-
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
// expected-error at +1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
%s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 3d133f2255772..c354de514fba8 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -58,61 +58,6 @@ func.func @mesh_shard_op_1st_and_3rd_dim(
return %0 : tensor<4x8x16xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_partial_max
-func.func @mesh_shard_op_partial_max(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = max [1] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0]] partial = max[1] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
-// CHECK-LABEL: func @mesh_shard_op_partial_min
-func.func @mesh_shard_op_partial_min(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = min [1] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0]] partial = min[1] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
-// CHECK-LABEL: func @mesh_shard_op_partial_generic
-func.func @mesh_shard_op_partial_generic(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = generic [1] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0]] partial = generic[1] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
-// CHECK-LABEL: func @mesh_shard_op_partial_sum
-func.func @mesh_shard_op_partial_sum(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
-// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
-func.func @mesh_shard_op_partial_sum_multi_axes(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
- %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1, 2] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1, 2] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
-
// CHECK-LABEL: func @mesh_shard_op_two_users
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index 9ceaadacd6f66..5e62c929aa4ff 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -166,17 +166,3 @@ func.func @unshard_static_axis_on_dynamic_mesh_axis(
// CHECK: return %[[RES]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
-
-// CHECK-LABEL: func @partial_axis_to_full_replication
-func.func @partial_axis_to_full_replication(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: %[[ALL_REDUCE]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index aa5fa00488f08..0881d994d60e7 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -20,8 +20,7 @@ func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
%0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
%s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
%1 = mesh.shard %0 to %s1 : tensor<8x16xf32>
// CHECK-NEXT: return %[[V2]]
@@ -37,8 +36,7 @@ func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
// CHECK-NEXT: return %[[V2]]
return %1 : tensor<8x16xf32>
}
@@ -51,8 +49,7 @@ func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16x
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
%0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
%s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
%1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: return %[[V3]]
@@ -64,13 +61,12 @@ func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16x
func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
// CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
%s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<8x16xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<8x16xf32>
}
@@ -91,8 +87,7 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2:
// CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
// CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
// CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
- // CHECK-NEXT: %[[S8:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S1]] : tensor<8x16xf32>
%2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
%s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
%3 = mesh.shard %2 to %s3 : tensor<8x16xf32>
@@ -111,50 +106,50 @@ func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: ten
// CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
%0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32>
%s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
%1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<2x16x32xf32>
}
-// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [], [1]] : !mesh.sharding
+ // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK: [[vsharded_3:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
%0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[], [1]] partial = sum [0] : !mesh.sharding
+ // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
+ // CHECK: [[vsharded_5:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
+ %s1 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
%1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
+ // CHECK-NEXT: return [[vsharded_5]]
return %1 : tensor<2x16x32xf32>
}
// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
- %1 = tosa.matmul %0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<2x16x32xf32>
+ // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
+ %s0 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
+ // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
+ %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x16x8xf32>
+ // CHECK: [[vsharded_0:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK: [[vsharding_1:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
+ // CHECK: [[vsharded_2:%.*]] = mesh.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_3:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK: [[vsharded_4:%.*]] = mesh.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ // CHECK: [[vsharding_5:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
+ %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK: return [[vsharded_6]]
+ return %0 : tensor<2x16x32xf32>
}
// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
@@ -172,7 +167,7 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
// CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
%2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
// CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
// CHECK-NEXT: return %[[V3]]
return %2 : tensor<2x16x32xf32>
@@ -200,8 +195,7 @@ func.func @resolve_conflicting_annotations(
// CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
%res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK-NEXT: %[[SRES:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SRES]] : tensor<2x2xf32>
+ // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
%sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
%res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32>
// CHECK: return %[[RES]] : tensor<2x2xf32>
@@ -209,76 +203,82 @@ func.func @resolve_conflicting_annotations(
}
// https://arxiv.org/abs/2211.05102 Figure 2(a)
+// The sharding propagation results in unnecessary reshards;
+// an optimization pass should be able to remove them.
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32>
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
- // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
- // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
- // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-DAG: %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users : tensor<1xf32>
- // CHECK: %[[V0:.*]] = tosa.matmul
- %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]] : tensor<2x4x32xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
+ %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
+ %sharded0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ %sharded1 = mesh.shard %arg1 to %s0 : tensor<2x8x32xf32>
+ // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
+ // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_0:%.*]] = mesh.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
+ // CHECK: [[vsharded_1:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0, 1, 2]] : !mesh.sharding
+ // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
%2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S2]] : tensor<2x4x32xf32>
- // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users : tensor<2x32x8xf32>
- // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]], %[[ZP]], %[[ZP]]
- %3 = tosa.matmul %2, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
- %s4 = mesh.sharding @mesh_1d split_axes = [[], [], []] partial = sum [0] : !mesh.sharding
+ %sharding = mesh.sharding @mesh_1d split_axes = [[], [0, 1, 2]] : !mesh.sharding
+ // CHECK: [[vsharded_9:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
+ %sharded2 = mesh.shard %arg2 to %sharding : tensor<2x32x8xf32>
+ // CHECK: [[vsharded_10:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+ // CHECK: [[v2:%.*]] = tosa.matmul
+ %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
+ // CHECK: [[vsharded_12:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+ %s4 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
%4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32>
- // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], []] partial = sum [0] : !mesh.sharding
- // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<2x4x8xf32>
- %s5 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
- %5 = mesh.shard %4 to %s5 annotate_for_users : tensor<2x4x8xf32>
- // CHECK: %[[V9:.*]] = mesh.shard %[[V8]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
- // CHECK-NEXT: return %[[V9]]
- return %5 : tensor<2x4x8xf32>
+ // CHECK: return [[vsharded_12]]
+ return %4 : tensor<2x4x8xf32>
}
// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// The sharding propagation results in unnecessary reshards;
+// an optimization pass should be able to remove them.
// CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32>
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
- // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] : tensor<2x4x8xf32>
+ // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
%s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
- // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-DAG: %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users : tensor<1xf32>
- // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]], %[[ZP]], %[[ZP]]
- %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding
- // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]] : tensor<2x4x32xf32>
- %s2 = mesh.sharding @mesh_3d split_axes = [[], [], [1, 2]] partial = sum [0] : !mesh.sharding
- %2 = mesh.shard %1 to %s2 : tensor<2x4x32xf32>
- // CHECK-DAG: %[[S5:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] : !mesh.sharding
- // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
+ // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+ %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
+ %s1 = mesh.sharding @mesh_3d split_axes = [[], [0], [1, 2]] : !mesh.sharding
+ // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
+ %arg1_s = mesh.shard %arg1 to %s1 : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_4:%.*]] = mesh.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
+ %2 = mesh.shard %1 to %s0 : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[v1:%.*]] = tosa.sigmoid
+ // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
%3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V6]] to %[[S5]] : tensor<2x4x32xf32>
- // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users : tensor<2x32x8xf32>
- // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]], %[[ZP]], %[[ZP]]
- %4 = tosa.matmul %3, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
- // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding
- // CHECK-NEXT: %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]] : tensor<2x4x8xf32>
- %s5 = mesh.sharding @mesh_3d split_axes = [[], [], [0]] partial = sum[1, 2] : !mesh.sharding
- %5 = mesh.shard %4 to %s5 : tensor<2x4x8xf32>
- // CHECK-NEXT: %[[V12:.*]] = mesh.shard %[[V11]] to %[[S0]] annotate_for_users : tensor<2x4x8xf32>
- %s6 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- %6 = mesh.shard %5 to %s6 annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: return %[[V12]]
+ // CHECK: [[vsharding_9:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
+ %s2 = mesh.sharding @mesh_3d split_axes = [[], [1, 2], [0]] : !mesh.sharding
+ // CHECK: [[vsharded_10:%.*]] = mesh.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
+ %arg2_s = mesh.shard %arg2 to %s2 : tensor<2x32x8xf32>
+ // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_12:%.*]] = mesh.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+ // CHECK: [[v2:%.*]] = tosa.matmul
+ %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
+ // CHECK: [[vsharded_13:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+ %5 = mesh.shard %4 to %s0 : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_14:%.*]] = mesh.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+ %6 = mesh.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: return [[vsharded_14]]
return %6 : tensor<2x4x8xf32>
}
@@ -293,8 +293,7 @@ func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16x
// CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S0]] : tensor<8x16xf32>
%s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
%2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
// CHECK-NEXT: return %[[V5]]
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index af4ab58ea50a3..701898cbdc74d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -270,4 +270,48 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200
%sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
// CHECK: return %[[UH]] : tensor<303x307xi64>
return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
\ No newline at end of file
+}
+
+mesh.mesh @mesh(shape = 2)
+// CHECK-LABEL: func.func @test_reduce_0d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) { // Full reduction (dims 0 and 1) of an input split on dim 0 over a 2-device mesh: expect a local linalg.reduce on the 3x6 shard followed by an all_reduce over mesh axis 0.
+ %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+ %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+ %4 = tensor.empty() : tensor<i32>
+ %sharding_out = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+ %sharded_out = mesh.shard %4 to %sharding_out : tensor<i32>
+ %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+ // CHECK: %[[reduced:.*]] = linalg.reduce ins(%[[ARG0]] : tensor<3x6xi32>)
+ %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %6 = arith.addi %in, %init : i32
+ linalg.yield %6 : i32
+ }
+ // CHECK: %[[all_reduce:.*]] = mesh.all_reduce %[[reduced]] on @mesh mesh_axes = [0] : tensor<i32> -> tensor<i32>
+ %sharded_red = mesh.shard %reduced to %sharding_out : tensor<i32>
+ %sharded_ret = mesh.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
+ // CHECK: return %[[all_reduce]] : tensor<i32>
+ return %sharded_ret : tensor<i32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_1d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) { // Reduction over dim 1 only, which is not split: each device reduces its 3x6 shard to a 3-element slice that stays split on dim 0, so no all_reduce is expected.
+ %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+ %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+ %4 = tensor.empty() : tensor<6xi32>
+ %sharded_out = mesh.shard %4 to %sharding : tensor<6xi32>
+ %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+ // CHECK: %[[reduced:.*]] = linalg.reduce ins(%[[ARG0]] : tensor<3x6xi32>)
+ %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
+ (%in: i32, %init: i32) {
+ %6 = arith.addi %in, %init : i32
+ linalg.yield %6 : i32
+ }
+ // CHECK-NOT: mesh.all_reduce
+ %sharded_red = mesh.shard %reduced to %sharding : tensor<6xi32>
+ %sharded_ret = mesh.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
+ // CHECK: return %[[reduced]] : tensor<3xi32>
+ return %sharded_ret : tensor<6xi32>
+}
More information about the Mlir-commits
mailing list