[Mlir-commits] [mlir] [mlir][mesh] removing partial/reduction axes from mesh.sharding (PR #149805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 21 05:41:38 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

As discussed on Discourse (thread 87053), this PR removes partial axes from sharding annotations.

The dialect provides types and operations across two distinct domains — sharding/partitioning and data exchange — yet one operation (`mesh.sharding`) conflated the two by allowing implicit `mesh.all_reduce` behavior when partial axes are specified. Beyond being conceptually unclean, this coupling complicates the analysis needed to generate sharding/partitioning plans and inflates data structures. Sharding should focus solely on defining the data layout across devices, while reductions and other communications should be treated as part of sharded operation semantics — not sharding itself. The `ShardingInterface` is the right abstraction for capturing operation-specific requirements. Its `spmdize` method can insert the appropriate communication when tensors are sharded. Moving the responsibility for adding reductions from `mesh.sharding` into `ShardingInterface.spmdize` cleanly separates concerns, simplifies the sharding syntax, and reduces the burden on authors of sharding annotations.

Some examples will currently lead to more resharding than before. The partial axes annotation was used to avoid unnecessary communication. This can (and should) also be done by a dedicated optimization pass. Parts of the necessary mechanics for this already exist.

@<!-- -->tkarna 

---

Patch is 71.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149805.diff


13 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+4-11) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+3-26) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (+6-24) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+17-52) 
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+8-74) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+18-84) 
- (modified) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (-48) 
- (modified) mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir (+1-1) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (-24) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (-55) 
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (-14) 
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+105-106) 
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+45-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 7213fde45c695..7cfe59dd957ca 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -43,8 +43,6 @@ class MeshSharding {
 private:
   ::mlir::FlatSymbolRefAttr mesh;
   SmallVector<MeshAxesAttr> split_axes;
-  SmallVector<MeshAxis> partial_axes;
-  ReductionKind partial_type = ReductionKind::Sum;
   SmallVector<int64_t> static_halo_sizes;
   SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
@@ -55,8 +53,6 @@ class MeshSharding {
   MeshSharding(Value rhs);
   static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
                           ArrayRef<MeshAxesAttr> split_axes_,
-                          ArrayRef<MeshAxis> partial_axes_ = {},
-                          ReductionKind partial_type_ = ReductionKind::Sum,
                           ArrayRef<int64_t> static_halo_sizes_ = {},
                           ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
@@ -64,8 +60,6 @@ class MeshSharding {
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
   ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
-  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
-  ReductionKind getPartialType() const { return partial_type; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
   ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
     return static_sharded_dims_offsets;
@@ -79,7 +73,7 @@ class MeshSharding {
   bool operator!=(Value rhs) const;
   bool operator==(const MeshSharding &rhs) const;
   bool operator!=(const MeshSharding &rhs) const;
-  bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
+  bool equalSplitAxes(const MeshSharding &rhs) const;
   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
   bool equalHaloSizes(const MeshSharding &rhs) const;
   bool equalShardSizes(const MeshSharding &rhs) const;
@@ -110,10 +104,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 // Is the same tensor replicated on all processes.
 inline bool isFullReplication(MeshSharding sharding) {
-  return sharding.getPartialAxes().empty() &&
-         llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
-           return axes.asArrayRef().empty();
-         });
+  return llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
+    return axes.asArrayRef().empty();
+  });
 }
 
 inline mesh::MeshOp
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index ac05ee243d7be..1662885c161e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -204,7 +204,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   let description = [{
     The MeshSharding specifies how a tensor is sharded and distributed across the
     process mesh. It is typically used in a `mesh.shard` operation.
-    The operation has the follwing attributes and operands:
+    The operation has the following attributes and operands:
 
     1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
     mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -215,15 +215,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     its value is [x, y], it indicates that the tensor's i-th dimension is splitted
     along the x and y axes of the device mesh.
 
-    3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
-    one along the specified mesh axes. An all-reduce should be applied to obtain
-    the complete tensor, with reduction type being specified by `partial_type`.
-
-    4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
-    op. It has 4 possible values:
-    `generic`: is not an allowed value inside a shard attribute.
-
-    5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+    3. [Optional] Sizes of halos to be added for each sharded tensor dimension.
     `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
     sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
     gets an additional halo of size 1 at the start of the first dimension and a halo
@@ -231,7 +223,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
     seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
     
-    6. [Optional] Offsets for each shard and sharded tensor dimension.
+    4. [Optional] Offsets for each shard and sharded tensor dimension.
     `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
     sharded tensor dimension the offsets (starting index) of all shards in that
     dimension and an additional value for the end of the last shard are provided.
@@ -260,14 +252,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
     %sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
 
-    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
-    // it is also a partial_sum along mesh axis 1.
-    %sharding2 = mesh.sharding @mesh0 split_axes = [[0] split_axes = []] partial = sum[1]
-
-    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
-    // it is also a partial_max along mesh axis 1.
-    %sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]
-
     // Could be used for a mesh.shard op
     %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
 
@@ -287,8 +271,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
-    OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
     Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
@@ -300,7 +282,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   let assemblyFormat = [{
     $mesh
     `split_axes` `=` $split_axes
-    (`partial` `=` $partial_type $partial_axes^)?
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
     (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
@@ -308,12 +289,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   let builders = [
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
-                   "ArrayRef<MeshAxis>":$partial_axes,
-                   "mesh::ReductionKind":$partial_type,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
-    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                   "ArrayRef<MeshAxesAttr>":$split_axes)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index ee1957aaa6a53..8208a3123050e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -187,27 +187,6 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
   return newOperands;
 }
 
-static void createAllReduceForResultWithoutPartialSharding(
-    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
-    MeshSharding resultSharding, ReductionKind reductionKind,
-    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
-  SmallVector<MeshAxis> allReduceMeshAxes;
-  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
-                [&resultSharding](MeshAxis axis) {
-                  return !llvm::is_contained(resultSharding.getPartialAxes(),
-                                             axis);
-                });
-  if (allReduceMeshAxes.empty()) {
-    return;
-  }
-
-  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
-  Value reducedValue = builder.create<mesh::AllReduceOp>(
-      spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
-      reductionKind);
-  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
-}
-
 static void createAllReduceForResultsWithoutPartialShardings(
     LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
     ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
@@ -215,9 +194,12 @@ static void createAllReduceForResultsWithoutPartialShardings(
   ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
   for (auto [unshardedLinalgOpResult, resultSharding] :
        llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
-    createAllReduceForResultWithoutPartialSharding(
-        unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
-        reductionKind, spmdizationMap, builder);
+    Value spmdizedLinalgOpResult =
+        spmdizationMap.lookup(unshardedLinalgOpResult);
+    Value reducedValue = builder.create<mesh::AllReduceOp>(
+        spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes,
+        reductionKind);
+    spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
   }
 }
 
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index cf506d1e7812b..61ca81dc9c4c1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -479,37 +479,23 @@ void MeshShapeOp::getAsmResultNames(
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        FlatSymbolRefAttr mesh,
                        ArrayRef<MeshAxesAttr> split_axes,
-                       ArrayRef<MeshAxis> partial_axes,
-                       mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halos,
                        ArrayRef<int64_t> static_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
-      ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
-      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
 }
 
-void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       FlatSymbolRefAttr mesh,
-                       ArrayRef<MeshAxesAttr> split_axes) {
-  return build(
-      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
-      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
-      {}, {}, {}, {});
-}
-
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
                        ArrayRef<int64_t> static_halos,
                        ArrayRef<int64_t> static_offsets) {
-  return build(
-      b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
-      MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
-      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
+               MeshAxesArrayAttr::get(b.getContext(), split_axes),
+               ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
+               ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
+               {});
 }
 
 void ShardingOp::build(
@@ -522,8 +508,7 @@ void ShardingOp::build(
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
   dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
-      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
-      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
 }
@@ -533,11 +518,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
 
   build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
         MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
-        from.getPartialAxes().empty()
-            ? DenseI16ArrayAttr()
-            : b.getDenseI16ArrayAttr(from.getPartialAxes()),
-        ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
-                                             from.getPartialType()),
         from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -566,9 +546,6 @@ LogicalResult ShardingOp::verify() {
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }
-  if (getPartialAxes().has_value() &&
-      failed(checkMeshAxis(getPartialAxes().value())))
-    return failure();
 
   if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
     return emitOpError("halo sizes and shard offsets are mutually exclusive");
@@ -710,17 +687,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
 // MeshSharding
 //===----------------------------------------------------------------------===//
 
-bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
+bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
   if (getMesh() != rhs.getMesh()) {
     return false;
   }
 
-  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
-      (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
-      !llvm::equal(getPartialAxes(), rhs.getPartialAxes())) {
-    return false;
-  }
-
   auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
   if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
                                     getSplitAxes().begin() + minSize),
@@ -768,13 +739,13 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::operator==(Value rhs) const {
-  return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
+  return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
 }
 
 bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
 
 bool MeshSharding::operator==(const MeshSharding &rhs) const {
-  return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
+  return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
 }
 
 bool MeshSharding::operator!=(const MeshSharding &rhs) const {
@@ -787,30 +758,26 @@ MeshSharding::MeshSharding(Value rhs) {
   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
   assert(shardingOp && "expected sharding op");
   auto splitAxes = shardingOp.getSplitAxes().getAxes();
-  auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
-  // If splitAxes and partialAxes are empty, use "empty" constructor.
-  if (splitAxes.empty() && partialAxes.empty()) {
+  // If splitAxes are empty, use "empty" constructor.
+  if (splitAxes.empty()) {
     *this = MeshSharding(shardingOp.getMeshAttr());
     return;
   }
-  *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
-              shardingOp.getPartialType().value_or(ReductionKind::Sum),
-              shardingOp.getStaticHaloSizes(),
-              shardingOp.getStaticShardedDimsOffsets(),
-              SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
-              SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
+  *this =
+      get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
+          shardingOp.getStaticShardedDimsOffsets(),
+          SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+          SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
 }
 
 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<MeshAxesAttr> split_axes_,
-                               ArrayRef<MeshAxis> partial_axes_,
-                               ReductionKind partial_type_,
                                ArrayRef<int64_t> static_halo_sizes_,
                                ArrayRef<int64_t> static_sharded_dims_offsets_,
                                ArrayRef<Value> dynamic_halo_sizes_,
                                ArrayRef<Value> dynamic_sharded_dims_offsets_) {
   MeshSharding res(mesh_);
-  if (split_axes_.empty() && partial_axes_.empty()) {
+  if (split_axes_.empty()) {
     return res;
   }
 
@@ -825,8 +792,6 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
     llvm::copy(src, dst.begin());
   };
 
-  clone(partial_axes_, res.partial_axes);
-  res.partial_type = partial_type_;
   clone(static_halo_sizes_, res.static_halo_sizes);
   clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
   clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dca2b1a52166a..6b3d49e08b549 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -271,7 +271,6 @@ mesh::detail::defaultGetShardingOption(Operation *op,
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
   shardingOption.shardingArray.resize(loopTypes.size());
-  llvm::SmallVector<MeshAxis> partialMeshAxes;
   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
   bool anyShardingInResultsOrOperands = false;
 
@@ -299,22 +298,6 @@ mesh::detail::defaultGetShardingOption(Operation *op,
           return failure();
       }
     }
-
-    // Handle the partial axes: at this stage, the exact loop index/indices
-    // cannot be decided because there could be multiple reduction loops.
-    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
-    if (!partialAxes.empty()) {
-      if (!partialMeshAxes.empty())
-        return op->emitOpError() << "at most one result with partial axes is "
-                                    "supported at present";
-      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
-      // Add all the reduction loop indices to `visitedLoopIndices` if
-      // `partialAxes` is not empty
-      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
-        if (isReductionLoop(loopTypes[loopIdx]))
-          visitedLoopIndices.insert(loopIdx);
-      }
-    }
   }
 
   // 2. Fill sharding option based on operands
@@ -327,8 +310,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
     AffineMap map = maps[shardingIt.index()];
     unsigned numDims = map.getNumDims();
 
-    // Handle the split axes. Partial axes don't need to be handled because they
-    // only affect the defining op of the operand.
+    // Handle the split axes.
     //
     // TODO: Change to process the operands with single loop index first and
     // then the operands with multiple loop indices.
@@ -372,28 +354,6 @@ mesh::detail::defaultGetShardingOption(Operation *op,
   }
 
   // 3. Finalize sharding option
-  if (!partialMeshAxes.empty()) {
-    bool anyNonEmptyReductionLoop = llvm::any_of(
-        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
-          SmallVector<MeshAxis> &subArray = it.value();
-          int64_t idx = it.inde...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/149805


More information about the Mlir-commits mailing list