[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)

Boian Petkantchin llvmlistbot at llvm.org
Wed May 22 09:20:29 PDT 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/84514

>From afe62eda545978358ce388a1c481ef45609f9d47 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Mar 2024 07:03:02 -0800
Subject: [PATCH 1/2] [mlir][mesh] Insert resharding during sharding
 propagation

If the existing sharding annotations of an op's operands or results conflict
with the sharding option selected for the op, insert resharding.
Make the Spmdization pass more forgiving to allow for more than 2 chained
`mesh.shard` ops.

Implement `getReductionLoopIteratorKinds` in ShardingInterface for linalg ops.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  28 ++-
 .../Mesh/Interfaces/ShardingInterface.h       |   9 +
 .../Mesh/Interfaces/ShardingInterface.td      |  25 +-
 .../Transforms/MeshShardingInterfaceImpl.cpp  |  15 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  94 ++++++-
 .../Mesh/Interfaces/ShardingInterface.cpp     | 119 ++++++---
 .../Mesh/Transforms/ShardingPropagation.cpp   | 231 ++++++++++++++++--
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |   3 -
 .../Linalg/mesh-sharding-propagation.mlir     |  34 +++
 .../Dialect/Mesh/sharding-propagation.mlir    |  38 ++-
 mlir/test/Dialect/Mesh/spmdization.mlir       |  15 ++
 12 files changed, 540 insertions(+), 73 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d9b5892e1a51..3a85bf2d552f3 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -151,7 +151,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let extraClassDeclaration = [{
     bool operator==(::mlir::Attribute rhs) const;
+    bool operator!=(::mlir::Attribute rhs) const;
     bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+    bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 4569b77441c3f..7a24c201a39a7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,15 +51,26 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 // Is the same tensor replicated on all processes.
 inline bool isFullReplication(MeshShardingAttr attr) {
-  return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+  // Split-axes entries that are empty do not shard any tensor dimension.
+  return attr.getPartialAxes().empty() &&
+         llvm::all_of(attr.getSplitAxes(),
+                      [](MeshAxesAttr axes) { return axes.empty(); });
 }
 
-inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                            SymbolTableCollection &symbolTableCollection) {
+inline mesh::MeshOp
+getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+              SymbolTableCollection &symbolTableCollection) {
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
       op, meshSymbol);
 }
 
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
+  assert(meshOp && "mesh symbol does not resolve to a mesh.mesh op");
+  return meshOp;
+}
+
 // Get the corresponding mesh op using the standard attribute nomenclature.
 template <typename Op>
 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
@@ -128,6 +139,17 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
 // `sharding` in that case must be null.
 Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
 
+// Insert a `mesh.shard` op (Target: annotate_for_users=false; Source: =true)
+// unless one with `sharding` is already there. May also insert resharding.
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                         OpOperand &operand,
+                                         OpBuilder &builder);
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                         OpResult result, OpBuilder &builder);
+void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+                                         OpOperand &operand,
+                                         OpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index c47a7ddd3f9cc..216d7e10296df 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -37,6 +37,11 @@ struct ShardingOption {
   ShardingOption() = default;
   ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
       : shardingArray(std::move(shardingArray)), mesh(mesh) {}
+  static ShardingOption makeEmpty() {
+    ShardingOption emptyOption{};
+    emptyOption.empty = true;
+    return emptyOption;
+  }
 };
 
 // This method retrieves the 'MeshShardingAttr' attribute from a given operation
@@ -56,6 +61,10 @@ defaultGetShardingOption(Operation *op,
                          ArrayRef<MeshShardingAttr> operandShardings,
                          ArrayRef<MeshShardingAttr> resultShardings);
 
+FailureOr<SmallVector<MeshShardingAttr>>
+defaultGetShardingAnnotations(Operation *op,
+                              const ShardingOption &shardingOption);
+
 LogicalResult
 defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
                               const ShardingOption &shardingOption);
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 1f75135f42882..47a74f619f56c 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -75,8 +75,11 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Given that certain operands or results of the operation may have
-          sharding annotations, this method leverages this information to deduce
-          how the operation should be sharded.
+          sharding annotations, this method leverages this information to
+          deduce how the operation should be sharded.
+          The passed sharding may be incomplete, this gives freedom for the
+          op to select the most appropriate shardings for all the operands
+          and results and the op itself.
         }],
         /*retTy=*/"FailureOr<ShardingOption>",
         /*methodName=*/"getShardingOption",
@@ -90,6 +93,24 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
             $_op.getOperation(), operandShardings, resultShardings);
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Based on a given ShardingOption, get the sharding annotations for
+          all operands of the operation, followed by those for all its results.
+          These are the shardings the operands and results need to have in
+          order to shard the op according to shardingOption.
+        }],
+        /*retTy=*/"FailureOr<SmallVector<MeshShardingAttr>>",
+        /*methodName=*/"getShardingAnnotations",
+        /*args=*/(ins
+          "const ShardingOption &":$shardingOption
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultGetShardingAnnotations(
+            $_op.getOperation(), shardingOption);
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Based on a given ShardingOption, this method adds `mesh.shard`
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..632ee6e7b5585 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -36,6 +36,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <iterator>
+#include <numeric>
 #include <optional>
 #include <utility>
 
@@ -279,6 +280,20 @@ struct StructuredOpShardingInterface
     return res;
   }
 
+  // Returns one entry per reduction loop of the linalg op, each set to the
+  // reduction kind returned by getReductionKindOfLinalgOp for the whole op.
+  SmallVector<ReductionKind>
+  getReductionLoopIteratorKinds(Operation *op) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    // Number of loops with reduction iterator semantics.
+    auto reductionItersCount = static_cast<unsigned>(
+        llvm::count(iteratorTypes, utils::IteratorType::reduction));
+    mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+    return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
+  }
+
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                         ArrayRef<MeshShardingAttr> operandShardings,
                         ArrayRef<MeshShardingAttr> resultShardings,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d4329b401df19..ec1acbbb93498 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include <algorithm>
 #include <functional>
 #include <iterator>
@@ -99,7 +101,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
                                           FlatSymbolRefAttr meshSymbol,
                                           SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
+  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
   return type;
 }
 
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  builder.setInsertionPointAfterValue(operandValue);
+  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
+  if (shardOp && shardOp.getShard() == sharding &&
+      !shardOp.getAnnotateForUsers()) {
+    // Nothing to do: the correct sharding is already set.
+    return;
+  }
+
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ false);
+  IRRewriter rewriter(builder);
+  // Rewire only this use. Redirecting every use owned by `operandOp` would
+  // also move sibling uses that callers may still be iterating over.
+  rewriter.modifyOpInPlace(operandOp, [&]() { operand.set(newShardOp); });
+  // A single shard op suffices unless the user is a target annotation.
+  if (!shardOp || shardOp.getAnnotateForUsers()) {
+    return;
+  }
+  // The user is a conflicting target annotation: add a source annotation so
+  // the chain newShardOp -> newShardOp2 -> shardOp expresses the resharding.
+  auto newShardOp2 = builder.create<ShardOp>(
+      operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+}
+
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpResult result,
+                                                     OpBuilder &builder) {
+  // Early-increment iteration: the visited use may be rewired to a shard op.
+  for (OpOperand &use : llvm::make_early_inc_range(result.getUses()))
+    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+}
+
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  Operation *operandSrcOp = operandValue.getDefiningOp();
+  bool isBlockArg = !operandSrcOp;
+  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
+
+  if (shardOp && shardOp.getShard() == sharding &&
+      shardOp.getAnnotateForUsers()) {
+    // Nothing to do: the correct sharding is already set.
+    return;
+  }
+
+  builder.setInsertionPoint(operandOp);
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ true);
+  IRRewriter rewriter(builder);
+  // Rewire only this operand: other operands of `operandOp` may use the same
+  // value but require a different sharding.
+  rewriter.modifyOpInPlace(operandOp, [&]() { operand.set(newShardOp); });
+
+  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
+    // No need for resharding.
+    return;
+  }
+
+  // `shardOp` requests another sharding for users; reshard via a target op.
+  builder.setInsertionPoint(newShardOp);
+  auto newPreceedingShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ false);
+  rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
+                             [&newShardOp](OpOperand &use) {
+                               return use.getOwner() ==
+                                      newShardOp.getOperation();
+                             });
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.mesh op
 //===----------------------------------------------------------------------===//
@@ -286,6 +370,10 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
   return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
 }
 
+bool MeshShardingAttr::operator!=(Attribute rhs) const {
+  return !this->operator==(rhs);
+}
+
 bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
   if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
     return false;
@@ -311,6 +399,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
                       std::mem_fn(&MeshAxesAttr::empty));
 }
 
+bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
+  return !this->operator==(rhs);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dbb9e667d4709..54fc91cb26427 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
@@ -388,22 +389,11 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
   return shardingOption;
 }
 
-//===----------------------------------------------------------------------===//
-// detail::defaultAddShardingAnnotations
-//===----------------------------------------------------------------------===//
-
-// To add a `mesh.shard` op for the given result, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
-static LogicalResult addShardOp(OpBuilder &b, OpResult result,
-                                const ShardingOption &shardingOption,
-                                AffineMap map,
-                                ArrayRef<utils::IteratorType> loopTypes,
-                                ArrayRef<ReductionKind> reductionLoopKinds) {
-  FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
-      getMeshShardingAttr(result);
-  if (succeeded(maybeSharding) && !maybeSharding->first)
-    return success();
-
+// Get the sharding attribute for the given result and sharding option.
+static MeshShardingAttr
+getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
+                     AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
+                     ArrayRef<ReductionKind> reductionLoopKinds) {
   auto resultType = cast<RankedTensorType>(result.getType());
   SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
   SmallVector<MeshAxis> partialAxes;
@@ -438,26 +428,15 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  MeshShardingAttr shardAttr = MeshShardingAttr::get(
-      b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPointAfterValue(result);
-  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
-                                   shardAttr, /*annotate_for_users*/ false);
-  result.replaceAllUsesExcept(shardOp, shardOp);
-  return success();
+  return MeshShardingAttr::get(result.getContext(), shardingOption.mesh,
+                               splitAxes, partialAxes, partialType);
 }
 
-// To add a `mesh.shard` op for the given operand, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
-static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
-                                const ShardingOption &shardingOption,
-                                AffineMap map) {
-  auto maybeShardingAttr = getMeshShardingAttr(opOperand);
-  if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
-    return success();
-  Value operand = opOperand.get();
-  auto operandType = cast<RankedTensorType>(operand.getType());
+static FailureOr<MeshShardingAttr>
+getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
+                     AffineMap map) {
+  Value operandValue = opOperand.get();
+  auto operandType = cast<RankedTensorType>(operandValue.getType());
   SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
@@ -483,19 +462,79 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  MeshShardingAttr shardAttr =
-      MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
+  return MeshShardingAttr::get(opOperand.get().getContext(),
+                               shardingOption.mesh, splitAxes);
+}
+
+FailureOr<SmallVector<MeshShardingAttr>>
+mesh::detail::defaultGetShardingAnnotations(
+    Operation *op, const ShardingOption &shardingOption) {
+  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+  SmallVector<utils::IteratorType> loopTypes =
+      shardingOp.getLoopIteratorTypes();
+  SmallVector<ReductionKind> reductionKinds =
+      shardingOp.getReductionLoopIteratorKinds();
+  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+  unsigned numOperands = op->getNumOperands();
+  // Maps are indexed by operand number, then by numOperands + result number.
+  assert(maps.size() >= numOperands + op->getNumResults() &&
+         "expected an indexing map for every operand and result");
+  // Layout of the result: all operand shardings, then all result shardings.
+  SmallVector<MeshShardingAttr> res;
+  res.reserve(numOperands + op->getNumResults());
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    FailureOr<MeshShardingAttr> shardingAttr = getShardingAttribute(
+        opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
+    if (failed(shardingAttr))
+      return failure();
+    res.push_back(*shardingAttr);
+  }
+  for (OpResult result : op->getResults())
+    res.push_back(getShardingAttribute(
+        result, shardingOption, maps[numOperands + result.getResultNumber()],
+        loopTypes, reductionKinds));
+  return res;
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultAddShardingAnnotations
+//===----------------------------------------------------------------------===//
+
+// Annotate `result` with the sharding derived from `shardingOption`, `map`,
+// `loopTypes` and `reductionLoopKinds`; may insert resharding. Always succeeds.
+static LogicalResult addShardOp(OpBuilder &b, OpResult result,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<utils::IteratorType> loopTypes,
+                                ArrayRef<ReductionKind> reductionLoopKinds) {
+  MeshShardingAttr shardAttr = getShardingAttribute(
+      result, shardingOption, map, loopTypes, reductionLoopKinds);
+  maybeInsertTargetShardingAnnotation(shardAttr, result, b);
+
+  return success();
+}
+
+// Annotate `opOperand` with the sharding derived from `shardingOption` and
+// `map`; may insert resharding. Fails if that sharding cannot be computed.
+static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
+                                const ShardingOption &shardingOption,
+                                AffineMap map) {
+
+  FailureOr<MeshShardingAttr> shardAttr =
+      getShardingAttribute(opOperand, shardingOption, map);
+  if (failed(shardAttr)) {
+    return failure();
+  }
   OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(opOperand.getOwner());
-  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
-                                   shardAttr, true);
-  opOperand.set(shardOp);
+  maybeInsertSourceShardingAnnotation(*shardAttr, opOperand, b);
 
   return success();
 }
 
 LogicalResult mesh::detail::defaultAddShardingAnnotations(
     Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
+  assert(!shardingOption.empty && shardingOption.mesh);
+
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   SmallVector<utils::IteratorType> loopTypes =
       shardingOp.getLoopIteratorTypes();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 29320f1e339f8..65d12a16085d3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,9 +12,16 @@
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <vector>
 
 namespace mlir {
@@ -30,6 +37,70 @@ namespace mesh {
 using namespace mlir;
 using namespace mlir::mesh;
 
+enum class ReshardingRquirementKind {
+  NO_RESHARDING = 0,
+  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS,
+  RESHARDING_FOR_EXPLICIT_ANNOTATIONS
+};
+
+#ifndef NDEBUG // LLVM_DEBUG is always defined; only debug builds print.
+
+template <typename T>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const SmallVector<T> &vec);
+template <typename... Ts>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const std::tuple<Ts...> &t);
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     ReshardingRquirementKind v);
+
+template <typename Stream, typename Range>
+static Stream &printRange(Stream &stream, Range &&range) {
+  stream << "[";
+  llvm::for_each(range, [&stream](auto &v) {
+    stream << v;
+    stream << ", ";
+  });
+  return stream << "]";
+}
+
+template <typename T>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const SmallVector<T> &vec) {
+  return printRange(stream, vec);
+}
+
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const ShardingOption &v) {
+  return stream << "{empty = " << v.empty << ", mesh = " << v.mesh
+                << ", shardingArray = " << v.shardingArray << "}";
+}
+
+template <typename Stream, typename... Ts, size_t... Is>
+static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
+                          std::index_sequence<Is...>) {
+  static_assert(sizeof...(Is) == sizeof...(Ts),
+                "Indices must have same number of elements as tuple types!");
+  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");
+
+  stream << "{";
+  ((stream << std::get<Is>(tuple) << ", "), ...);
+  return stream << "}";
+}
+
+template <typename... Ts>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const std::tuple<Ts...> &t) {
+  return printTuple(stream, t, std::index_sequence_for<Ts...>{});
+}
+
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     ReshardingRquirementKind v) {
+  return stream << static_cast<int>(v);
+}
+
+#endif // NDEBUG
+
 //===----------------------------------------------------------------------===//
 // Utilities
 //===----------------------------------------------------------------------===//
@@ -77,6 +148,138 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
   return allShardingAttrs;
 }
 
+// Classify how much resharding is required to give the operands and results
+// of `op` the shardings in `operandAndResultShardings` (operand shardings
+// first, then result shardings), compared to the `mesh.shard` ops already
+// attached to them. From best to worst:
+// 1. NO_RESHARDING: all existing annotations are compatible.
+// 2. NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS: only annotations that do not
+//   specifically target this op conflict. Annotations targeting this op are
+//   * operands that are the result of `mesh.shard` ops marked with
+//     `annotate_for_users`.
+//   * results that are annotated with `mesh.shard` ops without
+//     `annotate_for_users`.
+// 3. RESHARDING_FOR_EXPLICIT_ANNOTATIONS: an annotation targeting this op
+//   conflicts; returned as soon as such a conflict is found.
+// Smaller enumerator values are preferable (see selectShardingOption).
+
+static ReshardingRquirementKind getReshardingRquirementKind(
+    Operation *op,
+    const SmallVector<MeshShardingAttr> &operandAndResultShardings) {
+  ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
+  size_t operandsCount = op->getOperands().size();
+  assert(operandAndResultShardings.size() ==
+         operandsCount + op->getNumResults());
+
+  auto operandShardings =
+      llvm::make_range(operandAndResultShardings.begin(),
+                       operandAndResultShardings.begin() + operandsCount);
+  auto resultShardings =
+      llvm::make_range(operandAndResultShardings.begin() + operandsCount,
+                       operandAndResultShardings.end());
+
+  for (auto [operand, sharding] :
+       llvm::zip_equal(op->getOperands(), operandShardings)) {
+    ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
+    if (!shardOp) {
+      continue;
+    }
+    bool needsResharding = shardOp.getShardAttr() != sharding;
+    bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
+    if (needsResharding) {
+      if (isExplicitAnnotationForThisOp) {
+        // This is the worst case. No need to continue.
+        return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+      }
+      res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+    }
+  }
+
+  for (auto [result, sharding] :
+       llvm::zip_equal(op->getResults(), resultShardings)) {
+    for (auto user : result.getUsers()) {
+      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
+      if (!shardOp) {
+        continue;
+      }
+      bool needsResharding = shardOp.getShardAttr() != sharding;
+      bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
+      if (needsResharding) {
+        if (isExplicitAnnotationForThisOp) {
+          // This is the worst case. No need to continue.
+          return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+        }
+        res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+      }
+    }
+  }
+
+  return res;
+}
+
+// Select, among all candidate operand/result sharding combinations, the
+// sharding option of `shardingOp` needing the least resharding. Returns an
+// empty option if no candidate yields a non-empty sharding option.
+static FailureOr<ShardingOption> selectShardingOption(
+    ShardingInterface shardingOp,
+    ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,
+    ArrayRef<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs) {
+  SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
+      shardingOptionsAndReshardingRequirements;
+
+  for (ArrayRef<MeshShardingAttr> resultShardings :
+       possibleResultShardingAttrs) {
+    for (ArrayRef<MeshShardingAttr> operandShardings :
+         possibleOperandShardingAttrs) {
+      FailureOr<ShardingOption> shardingOption =
+          shardingOp.getShardingOption(operandShardings, resultShardings);
+      if (failed(shardingOption) || shardingOption->empty) {
+        continue;
+      }
+      // The shardings the op actually needs. They may differ from, or fill in
+      // entries missing in, operandShardings and resultShardings.
+      FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
+          shardingOp.getShardingAnnotations(*shardingOption);
+      if (failed(operandAndResultShardings)) {
+        return failure();
+      }
+
+      LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
+                        << *operandAndResultShardings << "\n";);
+
+      ReshardingRquirementKind reshardingRquirement =
+          getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
+      if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
+        // This is the best case. No need to go on.
+        return *shardingOption;
+      }
+
+      shardingOptionsAndReshardingRequirements.emplace_back(
+          std::move(*shardingOption), reshardingRquirement);
+    }
+  }
+
+  if (shardingOptionsAndReshardingRequirements.empty()) {
+    return ShardingOption::makeEmpty();
+  }
+
+  // std::min_element returns the first of several equally good options, so
+  // ties are resolved in favor of the candidate enumerated first.
+  auto bestIt = std::min_element(
+      shardingOptionsAndReshardingRequirements.begin(),
+      shardingOptionsAndReshardingRequirements.end(),
+      [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
+         const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
+        return std::get<ReshardingRquirementKind>(a) <
+               std::get<ReshardingRquirementKind>(b);
+      });
+
+  LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
+                    << shardingOptionsAndReshardingRequirements << "\n";);
+
+  return std::get<ShardingOption>(*bestIt);
+}
+
 // For each operation that implements the ShardingInterface, infer the sharding
 // option of the operation from its operands and/or results using the
 // `getShardingOption` method. If the inferred sharding option is not empty, add
@@ -135,32 +338,21 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
       getOrderedPossibleShardingAttrs(resultMustShardings,
                                       allowConflictsResultShardings);
-  FailureOr<ShardingOption> finalShardingOption = failure();
-  for (ArrayRef<MeshShardingAttr> resultShardings :
-       possibleResultShardingAttrs) {
-    if (succeeded(finalShardingOption))
-      break;
-    for (ArrayRef<MeshShardingAttr> operandShardings :
-         possibleOperandShardingAttrs) {
-      FailureOr<ShardingOption> shardingOption =
-          shardingOp.getShardingOption(operandShardings, resultShardings);
-      if (succeeded(shardingOption)) {
-        finalShardingOption = shardingOption;
-        break;
-      }
-    }
-  }
+  FailureOr<ShardingOption> shardingOption = selectShardingOption(
+      shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
 
-  if (failed(finalShardingOption)) {
+  if (failed(shardingOption)) {
     op->emitOpError() << "fail to get sharding option.";
     return failure();
   }
+
+  LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
+
   // sharding info is empty, return immediately
-  if (finalShardingOption->empty)
+  if (shardingOption->empty)
     return success();
 
-  if (failed(
-          shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
+  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
     op->emitOpError() << "fail to set sharding annotations.";
     return failure();
   }
@@ -199,6 +391,7 @@ struct ShardingPropagation
 
     LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
                       << funcOp << "\n");
+    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
 
     // 2. propagate in original order
     for (Operation &op : llvm::make_early_inc_range(block))
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6b1326d76bc4a..f3e4b15aec118 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -493,8 +493,6 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue) {
-  assert(!source.getAnnotateForUsers());
-  assert(target.getAnnotateForUsers());
   assert(source.getResult() == target.getOperand());
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
   return reshard(
@@ -628,7 +626,6 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
     targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
   } else {
     // Insert resharding.
-    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
     TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
         spmdizationMap.lookup(srcShardOp.getOperand()));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
new file mode 100644
index 0000000000000..59fd548dc2ef2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt \
+// RUN:   --verify-each \
+// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN:   %s | FileCheck %s
+
+mesh.mesh @mesh_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_parallel_axis
+func.func @matmul_shard_parallel_axis(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+  %arg0 : tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+  %arg1 : tensor<3x2xf32>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+  // CHECK: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[}}[0]]> : tensor<2x3xf32>
+  // CHECK: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x2xf32>
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32> // Split on dim 0, a parallel axis of the matmul.
+
+  // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME:  outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to <@mesh_2, {{\[}}[0]]> : tensor<2x2xf32>
+  // CHECK: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to <@mesh_2, {{\[}}[]]> annotate_for_users : tensor<2x2xf32>
+  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32> // Users want it replicated; expect a reshard after the split matmul result.
+
+  // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 270787ab51883..11a80594adb79 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
 
+mesh.mesh @mesh_2(shape = 2)
 mesh.mesh @mesh_1d(shape = ?)
 mesh.mesh @mesh_2d(shape = 2x4)
 mesh.mesh @mesh_3d(shape = ?x?x?)
@@ -73,12 +74,11 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
   // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
   // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
   %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V8:.*]] = tosa.negate %[[V7]]
-  // CHECK-NEXT:  %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]]
+  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
   %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
   %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
-  // CHECK-NEXT: return %[[V6]], %[[V9]]
+  // CHECK-NEXT: return %[[V6]], %[[V8]]
   return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
 }
 
@@ -135,6 +135,34 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
   return %2 : tensor<2x16x32xf32>
 }
 
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+  // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+  %arg0: tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+  %arg1: tensor<3x2xf32>,
+  // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+  // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
+  // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x2xf32>
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32> // Not annotate_for_users, so it does not target the matmul and may be resharded.
+
+  // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
+  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32> // Targets the matmul result directly, so it is kept and %arg0 is resharded to replicated.
+
+  // CHECK: return %[[MATMUL_SHARDED1]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
+
 // https://arxiv.org/abs/2211.05102 Figure 2(a)
 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2df247aba3515..d7a1e2fd9d279 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -16,6 +16,21 @@ func.func @full_replication(
   return %1 : tensor<2xi8>
 }
 
+// Three chained mesh.shard ops (split, same split for users, then replicated) lower to an all_gather.
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+  %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
+  %sharding_annotated = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xf32>
+  %sharding_annotated_0 = mesh.shard %sharding_annotated to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xf32>
+  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d, [[]]> : tensor<2xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
+  return %sharding_annotated_1 : tensor<2xf32>
+}
+
 // CHECK-LABEL: func @move_split_axis
 func.func @move_split_axis(
   // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>

>From c9d46bde07bb496a170281ece98b5369857436e7 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 22 May 2024 11:20:02 -0500
Subject: [PATCH 2/2] Remove commented code and add some comments

---
 .../Mesh/Transforms/ShardingPropagation.cpp        | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 65d12a16085d3..58a956beb41cb 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -148,8 +148,6 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
   return allShardingAttrs;
 }
 
-// From all the sharding options return the one that is most compatible with
-// the sharding annotations of operands and results of the operation.
 // The order of preference is form highest to lowest:
 // 1. No resharding is required (all existing annotations are compatible).
 // 2. No resharding for operands/results that have annotation specifically
@@ -160,11 +158,6 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
 //     `annotate_for_users`.
 // 3. All other cases. Resharding is required for operands/results with
 //   annotation targeting explicitly this operation.
-// size_t preferredShardingOption(Operation *op, const
-// SmallVector<ShardingOption>& shardingOptions) {
-
-// }
-
 ReshardingRquirementKind getReshardingRquirementKind(
     Operation *op,
     const SmallVector<MeshShardingAttr> &operandAndResultShardings) {
@@ -217,6 +210,13 @@ ReshardingRquirementKind getReshardingRquirementKind(
   return res;
 }
 
+// From all the operand and result sharding combinations,
+// return the sharding option that is most desirable.
+// The order of preference is:
+// 1. No resharding with respect to existing sharding annotations.
+// 2. Resharding only for values whose existing annotations do not target
+//    this op.
+// 3. Resharding of existing explicit sharding annotations for this op.
 static FailureOr<ShardingOption> selectShardingOption(
     ShardingInterface shardingOp,
     ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,



More information about the Mlir-commits mailing list