[Mlir-commits] [mlir] [mlir][mesh] Deduplicate iterator type and partial type information (PR #81920)

Boian Petkantchin llvmlistbot at llvm.org
Thu Feb 15 13:08:10 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/81920

The two types duplicated mostly the same values.
Here they are decomposed to carry orthogonal and complementary information.

Use `utils::IteratorType` instead of `mesh::IteratorType`. It now has only parallel and reduction values.

Rename `Partial` to `ReductionKind`.

Add `getReductionLoopIteratorKinds` method to `ShardingInterface`.

>From 956df9de2941b925a13711cdb04b01f694261e65 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Mon, 12 Feb 2024 17:34:52 -0800
Subject: [PATCH] [mlir][mesh] Deduplicate iterator type and partial type
 information

The two types duplicated mostly the same values.
Here they are decomposed to carry orthogonal and complementary information.

Use `utils::IteratorType` instead of `mesh::IteratorType`.
It now has only parallel and reduction values.

Rename `Partial` to `ReductionKind`.

Add `getReductionLoopIteratorKinds` method to `ShardingInterface`.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 40 +++++--------------
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  9 ++---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  6 +--
 .../Mesh/Interfaces/ShardingInterface.h       |  1 +
 .../Mesh/Interfaces/ShardingInterface.td      | 31 +++++++++++---
 .../Mesh/Interfaces/ShardingInterfaceImpl.h   | 17 ++++----
 .../Dialect/Mesh/Transforms/Simplifications.h |  2 +-
 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt       |  1 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 29 +-------------
 .../Mesh/Interfaces/ShardingInterface.cpp     | 36 +++++++++--------
 .../Mesh/Transforms/Simplifications.cpp       | 16 ++++----
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 13 ++++--
 12 files changed, 93 insertions(+), 108 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 04929f4869273d..fc2acc70381ef7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -41,7 +41,8 @@ def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16"
 // Mesh Enums.
 //===----------------------------------------------------------------------===//
 
-def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
+def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
+  "Reduction of an iterator/mesh dimension.", [
   I32EnumAttrCase<"Sum", 1, "sum">,
   I32EnumAttrCase<"Max", 2, "max">,
   I32EnumAttrCase<"Min", 3, "min">,
@@ -51,26 +52,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
-def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
-// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
-// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
-// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
-// is partial.
-def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
-  I32EnumAttrCase<"Parallel", 1, "parallel">,
-  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
-  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
-  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
-  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
-  I32EnumAttrCase<"Invalid", 100, "invalid">
-]> {
-    let genSpecializedAttr = 0;
-    let cppNamespace = "::mlir::mesh";
-}
-
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -83,14 +68,15 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
      "The mesh on which tensors are sharded.">:$mesh,
     ArrayRefParameter<"MeshAxesAttr">:$split_axes,
     OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
-    OptionalParameter<"::mlir::mesh::Partial">:$partial_type
+    OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
   );
 
   let summary = "Attribute that extends tensor type to distributed tensor type.";
 
   let description = [{
-    The MeshSharding attribute could be used in the encoding of a
-    `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
+    The MeshSharding attribute is used in a `mesh.shard` operation.
+    It specifies how a tensor is sharded and distributed across the process
+    mesh.
 
     1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
     mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -107,13 +93,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     4. `partial_type`: indicates the reduction type of the possible all-reduce
     op. It has 4 possible values:
-    - `partial_sum`: denotes it's an all-reduce-sum
-    - `partial_max`: denotes it's an all-reduce-max
-    - `partial_min`: denotes it's an all-reduce-min
-    - `partial_generic`: denotes that the all-reduce type is complex and cannot
-    be represented merely by a simple sum, max, or min. The exact reduction
-    computation may be derived from the semantics of the corresponding operation
-    or from the reduction computation IR
+    `generic`: is not an allowed value inside a shard attribute.
 
     Example:
 
@@ -149,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
                      "ArrayRef<MeshAxis>": $partial_axes,
-                     "mesh::Partial": $partial_type), [{
+                     "mesh::ReductionKind": $partial_type), [{
       SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
                   split_axes, [&](ArrayRef<MeshAxis> array) {
           return MeshAxesAttr::get($_ctxt, array);
@@ -159,7 +139,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     }]>,
     AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
-      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
+      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fb9425b96e68e2..4569b77441c3f3 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -38,9 +39,9 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 namespace mlir {
 namespace mesh {
 
-bool isReductionLoop(IteratorType iType);
-
-bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+inline bool isReductionLoop(utils::IteratorType iType) {
+  return iType == utils::IteratorType::reduction;
+}
 
 template <typename T>
 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
@@ -48,8 +49,6 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
     array.pop_back();
 }
 
-Partial getPartialTypeFromReduction(IteratorType iType);
-
 // Is the same tensor replicated on all processes.
 inline bool isFullReplication(MeshShardingAttr attr) {
   return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 96636d5347ff6e..8ba7c111aea6bb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -330,7 +330,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
     AnyRankedTensor:$result
@@ -629,7 +629,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
     DenseI64ArrayAttr:$root,
     Variadic<Index>:$root_dynamic
   ));
@@ -692,7 +692,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
     IndexAttr:$scatter_axis
   ));
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index cc90ddd40a6222..c47a7ddd3f9cc3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 4afb1c36a72f7b..1f75135f42882f 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -26,20 +26,39 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           output tensors.
 
           Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
-          types are parallel, parallel, reduction-sum. This indicates that M and
+          types are parallel, parallel, reduction. This indicates that M and
           N are traversed in parallel, while the K dimension is used for
           reduction.
-
-          Example 2: A softmax op's loop iterator types are parallel and
-          invalid. The second dimension is considered as invalid because it is
-          neither parallel nor any kind of reduction. 
         }],
-        /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*retType=*/"SmallVector<mlir::utils::IteratorType>",
         /*methodName=*/"getLoopIteratorTypes",
         /*args=*/(ins),
         /*methodBody=*/"",
         /*defaultImplementation=*/"return {};"
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the kind of all reduction loop iterators.
+          The order is the same as the result from
+          `getLoopIteratorTypes`.
+
+          Example 1:
+          iterator types =  (parallel, reduction, parallel, reduction)
+                                             ||                   ||
+          reduction kinds = (                sum,                 max)
+
+          Example 2:
+          A softmax op's loop iterator types are parallel and
+          reduction.
+          The reduction iterator will be of kind `generic`, since it is none of
+          the available presets.
+        }],
+        /*retType=*/"SmallVector<ReductionKind>",
+        /*methodName=*/"getReductionLoopIteratorKinds",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return the indexing maps attribute within the current operation.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index 8108386c2e0437..ffc9b6fb18be53 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -36,8 +36,9 @@ template <typename Op>
 struct IndependentParallelIteratorDomainShardingInterface
     : public ShardingInterface::ExternalModel<
           IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
-  SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
-    SmallVector<IteratorType> iterTypes;
+  SmallVector<utils::IteratorType>
+  getLoopIteratorTypes(Operation *operation) const {
+    SmallVector<utils::IteratorType> iterTypes;
     for (Type t : operation->getOperandTypes()) {
       populateIteratorTypes(t, iterTypes);
     }
@@ -65,8 +66,9 @@ struct IndependentParallelIteratorDomainShardingInterface
   }
 
 private:
-  void populateIteratorTypes(Type t,
-                             SmallVector<IteratorType> &iterTypes) const {
+  void
+  populateIteratorTypes(Type t,
+                        SmallVector<utils::IteratorType> &iterTypes) const {
     RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
     if (!rankedTensorType) {
       return;
@@ -74,7 +76,7 @@ struct IndependentParallelIteratorDomainShardingInterface
 
     iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
     for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
-      iterTypes.push_back(IteratorType::Parallel);
+      iterTypes.push_back(utils::IteratorType::parallel);
     }
   }
 };
@@ -84,12 +86,13 @@ template <typename ElemwiseOp>
 struct ElementwiseShardingInterface
     : public ShardingInterface::ExternalModel<
           ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
-  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     Value val = op->getOperand(0);
     auto type = val.getType().dyn_cast<RankedTensorType>();
     if (!type)
       return {};
-    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+    SmallVector<utils::IteratorType> types(type.getRank(),
+                                           utils::IteratorType::parallel);
     return types;
   }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f438465251bb06..c64da29ca64123 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -38,7 +38,7 @@ namespace mesh {
 // the algebraic structure.
 template <typename AlgebraicOp>
 void populateAllReduceEndomorphismSimplificationPatterns(
-    RewritePatternSet &patterns, Partial reduction) {
+    RewritePatternSet &patterns, ReductionKind reduction) {
   auto getEndomorphismOpOperand = [](Operation *op) {
     auto allReduceOp = llvm::cast<AllReduceOp>(op);
     return &allReduceOp.getInputMutable();
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 678a25f1c3cf58..45ac9edb280bc9 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMeshDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRDialectUtils
   MLIRIR
   MLIRSupport
   MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3291010d27428a..838255cf5a5ba3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -148,33 +148,6 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
   return success();
 }
 
-bool mesh::isReductionLoop(IteratorType iType) {
-  return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
-}
-
-bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
-  return (partial == Partial::Generic &&
-          iType == IteratorType::ReductionGeneric) ||
-         (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
-         (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
-         (partial == Partial::Min && iType == IteratorType::ReductionMin);
-}
-
-Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
-  switch (iType) {
-  case IteratorType::ReductionGeneric:
-    return Partial::Generic;
-  case IteratorType::ReductionSum:
-    return Partial::Sum;
-  case IteratorType::ReductionMax:
-    return Partial::Max;
-  case IteratorType::ReductionMin:
-    return Partial::Min;
-  default:
-    llvm_unreachable("No corresponding partial type can be found");
-  }
-}
-
 template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
@@ -278,7 +251,7 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                          FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
-                         ArrayRef<MeshAxis> partialAxes, Partial) {
+                         ArrayRef<MeshAxis> partialAxes, ReductionKind) {
   // TODO: At present mesh symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index b8b3841d947abd..fe3d7c44413fef 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/Debug.h"
@@ -163,7 +164,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
       return failure();
 
   // check loop types
-  SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
   if (loopTypes.size() == 0)
     return failure();
 
@@ -198,7 +199,7 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
   getOperation()->print(os);
   os << "\n";
   os << "loop types: [";
-  for (IteratorType type : getLoopIteratorTypes()) {
+  for (utils::IteratorType type : getLoopIteratorTypes()) {
     os << stringifyEnum(type) << " ";
   }
   os << "]\n";
@@ -257,12 +258,12 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 
   if (failed(shardingOp.verifyShardingInterfaceImpl()))
     return op->emitOpError() << "invalid sharding interface implementation";
-  SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<utils::IteratorType> loopTypes =
+      shardingOp.getLoopIteratorTypes();
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
   shardingOption.shardingArray.resize(loopTypes.size());
   llvm::SmallVector<MeshAxis> partialMeshAxes;
-  Partial partialType;
   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
   bool anyShardingInResultsOrOperands = false;
 
@@ -294,7 +295,6 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       if (!partialMeshAxes.empty())
         return op->emitOpError() << "at most one result with partial axes is "
                                     "supported at present";
-      partialType = shardAttr.getPartialType();
       partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
       // Add all the reduction loop indices to `visitedLoopIndices` if
       // `partialAxes` is not empty
@@ -370,8 +370,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
     if (!anyNonEmptyReductionLoop) {
       bool filled = false;
       for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
-        if (isReductionLoop(loopTypes[idx]) &&
-            areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+        if (isReductionLoop(loopTypes[idx])) {
           std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                            partialMeshAxes, idx);
           filled = true;
@@ -398,7 +397,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 const ShardingOption &shardingOption,
                                 AffineMap map,
-                                ArrayRef<IteratorType> loopTypes) {
+                                ArrayRef<utils::IteratorType> loopTypes,
+                                ArrayRef<ReductionKind> reductionLoopKinds) {
   FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
       getMeshShardingAttr(result);
   if (succeeded(maybeSharding) && !maybeSharding->first)
@@ -421,11 +421,13 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
 
   // process the partial axes
   // partialType will be ignored if partialAxes is empty
-  Partial partialType = Partial::Sum;
+  ReductionKind partialType = ReductionKind::Sum;
+  size_t reductionLoopKindsIdx = 0;
   for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
-    IteratorType iType = std::get<0>(it);
+    utils::IteratorType iType = std::get<0>(it);
     if (isReductionLoop(iType)) {
-      Partial curPartialType = getPartialTypeFromReduction(iType);
+      ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
+      ++reductionLoopKindsIdx;
       if (!partialAxes.empty())
         assert(partialType == curPartialType &&
                "Only one reduction type is supported");
@@ -450,8 +452,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
 // in `shardingOption`, `map`, and `loopTypes`.
 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                 const ShardingOption &shardingOption,
-                                AffineMap map,
-                                ArrayRef<IteratorType> loopTypes) {
+                                AffineMap map) {
   auto maybeShardingAttr = getMeshShardingAttr(opOperand);
   if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
     return success();
@@ -496,7 +497,10 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
 LogicalResult mesh::detail::defaultAddShardingAnnotations(
     Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
-  SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<utils::IteratorType> loopTypes =
+      shardingOp.getLoopIteratorTypes();
+  SmallVector<ReductionKind> reductionKinds =
+      shardingOp.getReductionLoopIteratorKinds();
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
 
@@ -504,14 +508,14 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
   for (OpResult result : op->getResults()) {
     if (failed(addShardOp(b, result, shardingOption,
                           maps[numOperands + result.getResultNumber()],
-                          loopTypes)))
+                          loopTypes, reductionKinds)))
       return failure();
   }
 
   // 2. add mesh.shard ops for all operands
   for (OpOperand &opOperand : op->getOpOperands()) {
     if (failed(addShardOp(b, opOperand, shardingOption,
-                          maps[opOperand.getOperandNumber()], loopTypes)))
+                          maps[opOperand.getOperandNumber()])))
       return failure();
   }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 7fcac2312444f3..a29551a43cb843 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -26,23 +26,23 @@ namespace mesh {
 void populateSimplificationPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
   populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
-      patterns, Partial::Sum);
+      patterns, ReductionKind::Sum);
   populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
-      patterns, Partial::Sum);
+      patterns, ReductionKind::Sum);
 
   populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
-      patterns, Partial::Min);
+      patterns, ReductionKind::Min);
   populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
-      patterns, Partial::Min);
+      patterns, ReductionKind::Min);
   populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
-      patterns, Partial::Min);
+      patterns, ReductionKind::Min);
 
   populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
-      patterns, Partial::Max);
+      patterns, ReductionKind::Max);
   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
-      patterns, Partial::Max);
+      patterns, ReductionKind::Max);
   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
-      patterns, Partial::Max);
+      patterns, ReductionKind::Max);
 
   // TODO: add simplifications for all-gather and other collectives.
 
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index fefbca8c2e091a..06a441dbeaf150 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -31,17 +31,22 @@ namespace {
 // (d0, d1, d2, d3) -> (d0, d1, d2)
 struct MatMulOpSharding
     : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
-  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
     if (!tensorType)
       return {};
 
-    SmallVector<IteratorType> types(tensorType.getRank() + 1,
-                                    IteratorType::Parallel);
-    types[tensorType.getRank()] = IteratorType::ReductionSum;
+    SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
+                                           utils::IteratorType::parallel);
+    types[tensorType.getRank()] = utils::IteratorType::reduction;
     return types;
   }
 
+  SmallVector<ReductionKind>
+  getReductionLoopIteratorKinds(Operation *op) const {
+    return SmallVector<ReductionKind>(1, ReductionKind::Sum);
+  }
+
   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
     auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
     if (!tensorType)



More information about the Mlir-commits mailing list