[Mlir-commits] [mlir] [mlir][mesh] Deduplicate iterator type and partial type information (PR #81920)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 15 13:08:44 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
The two types mostly duplicated the same values.
Here they are decomposed to carry orthogonal and complementary information.
Use `utils::IteratorType` instead of `mesh::IteratorType`. It now has only parallel and reduction values.
Rename `Partial` to `ReductionKind`.
Add `getReductionLoopIteratorKinds` method to `ShardingInterface`.
---
Patch is 24.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81920.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-30)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+4-5)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+3-3)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+1)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+25-6)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+10-7)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+1-1)
- (modified) mlir/lib/Dialect/Mesh/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+1-28)
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+20-16)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+8-8)
- (modified) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+9-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 04929f4869273d..fc2acc70381ef7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -41,7 +41,8 @@ def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16"
// Mesh Enums.
//===----------------------------------------------------------------------===//
-def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
+def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
+ "Reduction of an iterator/mesh dimension.", [
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
@@ -51,26 +52,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
-def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
let assemblyFormat = "`<` $value `>`";
}
-// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
-// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
-// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
-// is partial.
-def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
- I32EnumAttrCase<"Parallel", 1, "parallel">,
- I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
- I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
- I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
- I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
- I32EnumAttrCase<"Invalid", 100, "invalid">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::mesh";
-}
-
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -83,14 +68,15 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
"The mesh on which tensors are sharded.">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
- OptionalParameter<"::mlir::mesh::Partial">:$partial_type
+ OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
);
let summary = "Attribute that extends tensor type to distributed tensor type.";
let description = [{
- The MeshSharding attribute could be used in the encoding of a
- `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
+ The MeshSharding attribute is used in a `mesh.shard` operation.
+ It specifies how a tensor is sharded and distributed across the process
+ mesh.
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -107,13 +93,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
4. `partial_type`: indicates the reduction type of the possible all-reduce
op. It has 4 possible values:
- - `partial_sum`: denotes it's an all-reduce-sum
- - `partial_max`: denotes it's an all-reduce-max
- - `partial_min`: denotes it's an all-reduce-min
- - `partial_generic`: denotes that the all-reduce type is complex and cannot
- be represented merely by a simple sum, max, or min. The exact reduction
- computation may be derived from the semantics of the corresponding operation
- or from the reduction computation IR
+ `generic`: is not an allowed value inside a shard attribute.
Example:
@@ -149,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
- "mesh::Partial": $partial_type), [{
+ "mesh::ReductionKind": $partial_type), [{
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
@@ -159,7 +139,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>,
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
+ return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fb9425b96e68e2..4569b77441c3f3 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -38,9 +39,9 @@ using MeshAxesAttr = DenseI16ArrayAttr;
namespace mlir {
namespace mesh {
-bool isReductionLoop(IteratorType iType);
-
-bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+inline bool isReductionLoop(utils::IteratorType iType) {
+ return iType == utils::IteratorType::reduction;
+}
template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
@@ -48,8 +49,6 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
array.pop_back();
}
-Partial getPartialTypeFromReduction(IteratorType iType);
-
// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshShardingAttr attr) {
return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 96636d5347ff6e..8ba7c111aea6bb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -330,7 +330,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
- DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
));
let results = (outs
AnyRankedTensor:$result
@@ -629,7 +629,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
- DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
@@ -692,7 +692,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
- DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
IndexAttr:$scatter_axis
));
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index cc90ddd40a6222..c47a7ddd3f9cc3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 4afb1c36a72f7b..1f75135f42882f 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -26,20 +26,39 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
output tensors.
Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
- types are parallel, parallel, reduction-sum. This indicates that M and
+ types are parallel, parallel, reduction. This indicates that M and
N are traversed in parallel, while the K dimension is used for
reduction.
-
- Example 2: A softmax op's loop iterator types are parallel and
- invalid. The second dimension is considered as invalid because it is
- neither parallel nor any kind of reduction.
}],
- /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*retType=*/"SmallVector<mlir::utils::IteratorType>",
/*methodName=*/"getLoopIteratorTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return {};"
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the kind of all reduction loop iterators.
The order is the same as the result from
+ `getLoopIteratorTypes`.
+
+ Example 1:
+ iterator types = (parallel, reduction, parallel, reduction)
+ || ||
+ reduction kinds = ( sum, max)
+
+ Example 2:
+ A softmax op's loop iterator types are parallel and
+ reduction.
+ The reduction iterator will be of kind `generic`, since it is none of
+ the available presets.
+ }],
+ /*retType=*/"SmallVector<ReductionKind>",
+ /*methodName=*/"getReductionLoopIteratorKinds",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
InterfaceMethod<
/*desc=*/[{
Return the indexing maps attribute within the current operation.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index 8108386c2e0437..ffc9b6fb18be53 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -36,8 +36,9 @@ template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
: public ShardingInterface::ExternalModel<
IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
- SmallVector<IteratorType> iterTypes;
+ SmallVector<utils::IteratorType>
+ getLoopIteratorTypes(Operation *operation) const {
+ SmallVector<utils::IteratorType> iterTypes;
for (Type t : operation->getOperandTypes()) {
populateIteratorTypes(t, iterTypes);
}
@@ -65,8 +66,9 @@ struct IndependentParallelIteratorDomainShardingInterface
}
private:
- void populateIteratorTypes(Type t,
- SmallVector<IteratorType> &iterTypes) const {
+ void
+ populateIteratorTypes(Type t,
+ SmallVector<utils::IteratorType> &iterTypes) const {
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
if (!rankedTensorType) {
return;
@@ -74,7 +76,7 @@ struct IndependentParallelIteratorDomainShardingInterface
iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
- iterTypes.push_back(IteratorType::Parallel);
+ iterTypes.push_back(utils::IteratorType::parallel);
}
}
};
@@ -84,12 +86,13 @@ template <typename ElemwiseOp>
struct ElementwiseShardingInterface
: public ShardingInterface::ExternalModel<
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
if (!type)
return {};
- SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+ SmallVector<utils::IteratorType> types(type.getRank(),
+ utils::IteratorType::parallel);
return types;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f438465251bb06..c64da29ca64123 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -38,7 +38,7 @@ namespace mesh {
// the algebraic structure.
template <typename AlgebraicOp>
void populateAllReduceEndomorphismSimplificationPatterns(
- RewritePatternSet &patterns, Partial reduction) {
+ RewritePatternSet &patterns, ReductionKind reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
return &allReduceOp.getInputMutable();
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 678a25f1c3cf58..45ac9edb280bc9 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMeshDialect
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRDialectUtils
MLIRIR
MLIRSupport
MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3291010d27428a..838255cf5a5ba3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -148,33 +148,6 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
return success();
}
-bool mesh::isReductionLoop(IteratorType iType) {
- return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
-}
-
-bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
- return (partial == Partial::Generic &&
- iType == IteratorType::ReductionGeneric) ||
- (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
- (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
- (partial == Partial::Min && iType == IteratorType::ReductionMin);
-}
-
-Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
- switch (iType) {
- case IteratorType::ReductionGeneric:
- return Partial::Generic;
- case IteratorType::ReductionSum:
- return Partial::Sum;
- case IteratorType::ReductionMax:
- return Partial::Max;
- case IteratorType::ReductionMin:
- return Partial::Min;
- default:
- llvm_unreachable("No corresponding partial type can be found");
- }
-}
-
template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
@@ -278,7 +251,7 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
- ArrayRef<MeshAxis> partialAxes, Partial) {
+ ArrayRef<MeshAxis> partialAxes, ReductionKind) {
// TODO: At present mesh symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index b8b3841d947abd..fe3d7c44413fef 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
@@ -163,7 +164,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
return failure();
// check loop types
- SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+ SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
if (loopTypes.size() == 0)
return failure();
@@ -198,7 +199,7 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
getOperation()->print(os);
os << "\n";
os << "loop types: [";
- for (IteratorType type : getLoopIteratorTypes()) {
+ for (utils::IteratorType type : getLoopIteratorTypes()) {
os << stringifyEnum(type) << " ";
}
os << "]\n";
@@ -257,12 +258,12 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (failed(shardingOp.verifyShardingInterfaceImpl()))
return op->emitOpError() << "invalid sharding interface implementation";
- SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+ SmallVector<utils::IteratorType> loopTypes =
+ shardingOp.getLoopIteratorTypes();
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
shardingOption.shardingArray.resize(loopTypes.size());
llvm::SmallVector<MeshAxis> partialMeshAxes;
- Partial partialType;
llvm::SmallSet<unsigned, 4> visitedLoopIndices;
bool anyShardingInResultsOrOperands = false;
@@ -294,7 +295,6 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (!partialMeshAxes.empty())
return op->emitOpError() << "at most one result with partial axes is "
"supported at present";
- partialType = shardAttr.getPartialType();
partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
// Add all the reduction loop indices to `visitedLoopIndices` if
// `partialAxes` is not empty
@@ -370,8 +370,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (!anyNonEmptyReductionLoop) {
bool filled = false;
for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
- if (isReductionLoop(loopTypes[idx]) &&
- areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+ if (isReductionLoop(loopTypes[idx])) {
std::ignore = fillShardingOption(op, shardingOption, nullptr,
partialMeshAxes, idx);
filled = true;
@@ -398,7 +397,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
const ShardingOption &shardingOption,
AffineMap map,
- ArrayRef<IteratorType> loopTypes) {
+ ArrayRef<utils::IteratorType> loopTypes,
+ ArrayRef<ReductionKind> reductionLoopKinds) {
FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
getMeshShardingAttr(result);
if (succeeded(maybeSharding) && !maybeSharding->first)
@@ -421,11 +421,13 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
// process the partial axes
// partialType will be ignored if partialAxes is empty
- Partial partialType = Partial::Sum;
+ ReductionKind partialType = ReductionKind::Sum;
+ size_t reductionLoopKindsIdx = 0;
for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
- IteratorType iType = std::get<0>(it);
+ utils::IteratorType iType = std::get<0>(it);
if (isReductionLoop(iType)) {
- Partial curPartialType = getPartialTypeFromReduction(iType);
+ ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
+ ++reductionLoopKindsIdx;
if (!partialAxes.empty())
assert(partialType == curPartialType &&
"Only one reduction type is supported");
@@ -450,8 +452,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
// in `shardingO...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81920
More information about the Mlir-commits
mailing list