[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)

Chengji Yao llvmlistbot at llvm.org
Wed Oct 25 19:43:38 PDT 2023


https://github.com/yaochengji updated https://github.com/llvm/llvm-project/pull/69665

>From 25d68a6caf9c32747ab011b6d955d914573c2f55 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Wed, 18 Oct 2023 18:54:49 +0000
Subject: [PATCH 1/9] [MLIR][Mesh] Add sharding propagation pass

---
 mlir/include/mlir/Dialect/Mesh/CMakeLists.txt |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  34 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  22 +
 .../Dialect/Mesh/Interfaces/CMakeLists.txt    |   4 +
 .../Mesh/Interfaces/ShardingInterface.h       |  58 ++
 .../Mesh/Interfaces/ShardingInterface.td      |  87 +++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   6 +
 .../mlir/Dialect/Mesh/Transforms/Passes.h     |  41 ++
 .../mlir/Dialect/Mesh/Transforms/Passes.td    |  33 ++
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.h   |  23 +
 .../mlir/Dialect/Utils/IndexingUtils.h        |   3 +
 mlir/include/mlir/IR/AffineMap.h              |  12 +
 mlir/include/mlir/IR/Builders.h               |   1 +
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/include/mlir/InitAllPasses.h             |   2 +
 mlir/lib/Dialect/Mesh/CMakeLists.txt          |   2 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  32 +-
 .../Dialect/Mesh/Interfaces/CMakeLists.txt    |  15 +
 .../Mesh/Interfaces/ShardingInterface.cpp     | 544 ++++++++++++++++++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |  17 +
 .../Mesh/Transforms/ShardingPropagation.cpp   | 155 +++++
 mlir/lib/Dialect/Tosa/CMakeLists.txt          |  14 +
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 111 ++++
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      |  13 +
 mlir/lib/IR/AffineMap.cpp                     |  13 +
 mlir/lib/IR/Builders.cpp                      |   9 +
 .../Dialect/Mesh/sharding-propagation.mlir    | 167 ++++++
 27 files changed, 1421 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/Mesh/sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..b6623ed818f0770 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+  I32EnumAttrCase<"Parallel", 1, "parallel">,
+  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+  I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+    let genSpecializedAttr = 0;
+    let cppNamespace = "::mlir::mesh";
+}
+
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
        $partial_axes^ `]`)? `>`
   }];
 
+  let builders = [
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
+                     "ArrayRef<int32_t>": $partial_axes,
+                     "mesh::Partial": $partial_type), [{
+      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
+        llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+          return DenseI32ArrayAttr::get($_ctxt, array);
+      }));
+      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+                   partial_type);
+    }]>,
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+    }]>
+  ];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..cb86887091330c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,26 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+  for (int64_t i = array.size() - 1; i >= 0; i--) {
+    if (array[i].empty())
+      array.pop_back();
+    else
+      break;
+  }
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..1d19e41ac1fc555
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+  // An array of int arrays. The sub-array at the i-th position lists the mesh
+  // axes that the i-th loop will be sharded on.
+  ShardingArray shardingArray;
+  SymbolRefAttr cluster; // Mesh cluster whose axes `shardingArray` refers to.
+  // `empty` being true indicates that no sharding information can be inferred
+  // yet. Note that this differs from an operation that is simply not sharded.
+  bool empty = false;
+  ShardingOption() = default;
+  ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+      : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+namespace detail {
+
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+                              const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..c98b9f081492997
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,87 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+    let description = [{
+        Interface for allowing operations to expose information needed to
+        shard them.
+    }];
+    let cppNamespace = "::mlir::mesh";
+
+    let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns one iterator type per loop of the operation.
+        }],
+        /*retTy=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*methodName=*/"getLoopIteratorTypes",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns one indexing map per operand, followed by one per result.
+        }],
+        /*retTy=*/"SmallVector<AffineMap>",
+        /*methodName=*/"getIndexingMaps",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Given that certain operands or results of the operation may have
+          sharding annotations, this method leverages this information to deduce
+          how the operation should be sharded.
+        }],
+        /*retTy=*/"FailureOr<ShardingOption>",
+        /*methodName=*/"getShardingOption",
+        /*args=*/(ins
+          "OpBuilder &":$b
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultGetShardingOption(
+            $_op.getOperation(), b);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Based on a given ShardingOption, this method adds `mesh.shard`
+          operations for the operands and results that previously lacked
+          sharding annotations.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"addShardingAnnotations",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "const ShardingOption &":$shardingOption
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultAddShardingAnnotations(
+            $_op.getOperation(), b, shardingOption);
+        }]
+      >
+    ];
+
+    let extraClassDeclaration = [{
+      LogicalResult verifyShardingInterfaceImpl();
+
+      void printLoopTypesAndIndexingMaps(raw_ostream &os);
+    }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..aa3555f7f186f24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..d36adfe476a72ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,33 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+  let summary = "sharding propagation";
+  let description = [{
+    Propagates sharding information throughout the graph. After this pass, each
+    of the operations' operands and results is annotated with a `mesh.shard`
+    operation, and the operations themselves are annotated with sharding option
+    attributes.
+  }];
+  let constructor = "mlir::mesh::createShardingPropagationPass()";
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548ed..b24164cfb552b4f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
+/// Helper to return a vector of int32_t sub-vectors read from `arrayAttr`.
+SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
+
 /// Compute linear index from provided strides and indices, assuming strided
 /// layout.
 /// Returns AffineExpr and list of values to apply to it, e.g.:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 3430db2b99c3f2e..18e2313ef2b446b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -101,6 +101,18 @@ class AffineMap {
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns an affine map with `numDims` input dimensions and results
+  /// specified by `targets`.
+  ///
+  /// Examples:
+  /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+  /// * getMultiDimMapWithTargets(3, [2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d2, d1)>
+  static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+                                      ArrayRef<int64_t> targets,
+                                      MLIRContext *context);
+
   /// Returns a vector of AffineMaps; each with as many results as
   /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
   /// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5e54d4ea49e8251..3988835622b7629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -168,6 +168,7 @@ class Builder {
   ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
   ArrayAttr getTypeArrayAttr(TypeRange values);
+  ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
 
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..3556f82023828b2 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+  tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..27711417ed91a8c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -73,6 +74,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
   registerSCFPasses();
   registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+  return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+  return (partial == Partial::Generic &&
+          iType == IteratorType::ReductionGeneric) ||
+         (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
+         (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
+         (partial == Partial::Min && iType == IteratorType::ReductionMin);
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return Partial::Generic;
+  case IteratorType::ReductionSum:
+    return Partial::Sum;
+  case IteratorType::ReductionMax:
+    return Partial::Max;
+  case IteratorType::ReductionMin:
+    return Partial::Min;
+  default: // Parallel and Invalid have no partial counterpart.
+    llvm_unreachable("No corresponding partial type can be found");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.cluster op
 //===----------------------------------------------------------------------===//
@@ -95,7 +126,6 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   }
   if (failed(checkMeshAxis(partialAxes)))
     return failure();
-
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..1010756f1fe279a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_library(MLIRShardingInterface
+  ShardingInterface.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRShardingInterfaceIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMeshDialect
+  MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
new file mode 100644
index 000000000000000..7d5c73851bb1f4e
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -0,0 +1,544 @@
+//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// common util functions
+//===----------------------------------------------------------------------===//
+
+static FailureOr<ShardingOption> getShardingOptionFromAttr(Operation *op) {
+  auto arrayAttr = op->getAttrOfType<ArrayAttr>(getShardingArrayName());
+  if (!arrayAttr)
+    return failure();
+  auto symbolRefAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName());
+  if (!symbolRefAttr)
+    return failure();
+  return ShardingOption(getArrayOfI32Array(arrayAttr), symbolRefAttr);
+}
+
+// Returns `result`'s sharding from a def-side mesh.shard user, or, when
+// `useOperandSharding` is set, from its annotate_for_users mesh.shard users.
+static FailureOr<MeshShardingAttr>
+getMeshShardingAttr(OpResult result, bool useOperandSharding) {
+  Value val = result;
+  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      return false;
+    return !shardOp.getAnnotateForUsers();
+  });
+
+  if (anyShardedForDef) {
+    assert(val.hasOneUse() &&
+           "expected to have exactly one use if it has a use of mesh.shard "
+           "without unit attr annotate_for_users");
+    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
+    return shardOp.getShard();
+  } else if (useOperandSharding) {
+    bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
+      auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+      if (!shardOp)
+        return false;
+      return shardOp.getAnnotateForUsers();
+    });
+    if (anyShardedForUsers) {
+      SmallVector<ShardOp> shardOps;
+      for (Operation *user : val.getUsers()) {
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
+        if (shardOp)
+          shardOps.push_back(shardOp);
+      }
+      MeshShardingAttr shardForDef = shardOps[0].getShard();
+      for (size_t i = 1; i < shardOps.size(); ++i) {
+        // TODO: Deduce a reasonable mesh sharding attr for def when they are
+        // different
+        assert(shardOps[i].getShard() == shardForDef &&
+               "only support all shard ops have the same mesh sharding attr");
+      }
+      return shardForDef;
+    }
+  }
+
+  return failure();
+}
+
+// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
+// for a given operation operand.
+static FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand) {
+  Value val = opOperand.get();
+  if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) {
+    return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
+  }
+
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+checkOperandAffineExprRecursively(AffineExpr expr,
+                                  SmallVectorImpl<bool> &seenIds) {
+  switch (expr.getKind()) {
+  case AffineExprKind::Add: {
+    auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = binOpExpr.getLHS();
+    AffineExpr rhs = binOpExpr.getRHS();
+    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
+      return failure();
+    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
+      return failure();
+    return success();
+  }
+  case AffineExprKind::Mul: {
+    auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = binOpExpr.getLHS();
+    AffineExpr rhs = binOpExpr.getRHS();
+    AffineExpr dimExpr;
+    if (lhs.getKind() == AffineExprKind::DimId) {
+      dimExpr = lhs;
+      if (rhs.getKind() != AffineExprKind::Constant)
+        return failure();
+    } else if (rhs.getKind() == AffineExprKind::DimId &&
+               lhs.getKind() == AffineExprKind::Constant) {
+      dimExpr = rhs;
+    } else
+      return failure();
+    unsigned position = dimExpr.cast<AffineDimExpr>().getPosition();
+    if ((size_t)position >= seenIds.size() || seenIds[position])
+      return failure();
+    seenIds[position] = true;
+    return success();
+  }
+  case AffineExprKind::DimId: {
+    unsigned position = expr.cast<AffineDimExpr>().getPosition();
+    if ((size_t)position >= seenIds.size() || seenIds[position])
+      return failure();
+    seenIds[position] = true;
+    return success();
+  }
+  default:
+    return failure();
+  }
+}
+
+static FailureOr<llvm::SmallSet<unsigned, 2>>
+checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
+  SmallVector<bool> seenIds(numDims, false);
+  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
+    return failure();
+
+  llvm::SmallSet<unsigned, 2> positions;
+  for (auto it : llvm::enumerate(seenIds)) {
+    if (it.value())
+      positions.insert((unsigned)it.index());
+  }
+  return positions;
+}
+
+LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+  Operation *op = getOperation();
+
+  // All operands and results must be ranked tensors.
+  for (Type type : op->getOperandTypes())
+    if (!llvm::isa<RankedTensorType>(type))
+      return failure();
+  for (Type type : op->getResultTypes())
+    if (!llvm::isa<RankedTensorType>(type))
+      return failure();
+
+  // The op must report at least one loop iterator type.
+  SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+  if (loopTypes.empty())
+    return failure();
+
+  // There must be exactly one indexing map per operand and per result.
+  SmallVector<AffineMap> maps = getIndexingMaps();
+  if (maps.empty())
+    return failure();
+  unsigned numOperands = op->getNumOperands();
+  unsigned numResults = op->getNumResults();
+  if (numOperands + numResults != maps.size())
+    return failure();
+
+  for (OpResult result : op->getResults()) {
+    auto resultType = llvm::dyn_cast<RankedTensorType>(result.getType());
+    if (!resultType)
+      return failure();
+    AffineMap map = maps[numOperands + result.getResultNumber()];
+    if (!map.isProjectedPermutation()) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::printLoopTypesAndIndexingMaps
+//===----------------------------------------------------------------------===//
+
+// Debug helper: prints the op, then its loop iterator types (space separated,
+// inside brackets), then each of its indexing maps on its own line.
+void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+  os << "print loop types and indexing maps for: \n";
+  getOperation()->print(os);
+  os << "\n"
+     << "loop types: [";
+  for (IteratorType iteratorType : getLoopIteratorTypes())
+    os << stringifyEnum(iteratorType) << " ";
+  os << "]\n"
+     << "indexing maps: \n";
+  for (AffineMap indexingMap : getIndexingMaps())
+    os << indexingMap << "\n";
+  os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultGetShardingOption
+//===----------------------------------------------------------------------===//
+
+// Update the given `shardingOption` so that loop iterator `loopIdx` is sharded
+// along `meshAxes` of mesh `cluster`.
+//
+// The update is rejected when it clashes with what is already recorded:
+//  - a different (non-null) cluster,
+//  - a different non-empty sub-array for `loopIdx`, or
+//  - one of `meshAxes` already sharding another loop iterator.
+// On a clash, this emits a diagnostic on `op` and fails. If
+// `ignoreIfConflicted` is set, it instead returns success and leaves
+// `shardingOption` unchanged.
+static LogicalResult
+fillShardingOption(Operation *op, ShardingOption &shardingOption,
+                   SymbolRefAttr cluster, ArrayRef<int32_t> meshAxes,
+                   unsigned loopIdx, bool ignoreIfConflicted = false) {
+  if ((shardingOption.cluster && cluster &&
+       shardingOption.cluster != cluster) ||
+      (!shardingOption.shardingArray[loopIdx].empty() &&
+       shardingOption.shardingArray[loopIdx] != meshAxes)) {
+    if (ignoreIfConflicted)
+      return success();
+    return op->emitOpError()
+           << "sharding option conflicts on loop iterator " << loopIdx;
+  }
+
+  // A mesh axis may shard at most one loop iterator.
+  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
+    if (i == loopIdx)
+      continue;
+    for (int32_t axis : meshAxes) {
+      if (!llvm::is_contained(shardingOption.shardingArray[i], axis))
+        continue;
+      if (ignoreIfConflicted)
+        return success();
+      return op->emitOpError()
+             << "sharding option conflicts because of mesh axis " << axis
+             << " duplicates";
+    }
+  }
+
+  if (cluster)
+    shardingOption.cluster = cluster;
+  if (shardingOption.shardingArray[loopIdx].empty())
+    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
+                                                 meshAxes.end());
+  return success();
+}
+
+// Default implementation of ShardingInterface::getShardingOption.
+//
+// If `op` already carries a valid sharding attribute, that is returned as-is.
+// Otherwise, a sharding option (a mesh cluster plus, for every loop iterator,
+// the mesh axes sharding it) is derived from the sharding annotations on the
+// op's results and operands. The op's indexing maps are used to translate
+// annotated tensor dimensions into loop iterators. The returned option is
+// marked `empty` when no result or operand is annotated.
+FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(Operation *op,
+                                                                 OpBuilder &b) {
+
+  // 1. If a valid sharding attribute exists, use it.
+  FailureOr<ShardingOption> shardingOptionFromAttr =
+      getShardingOptionFromAttr(op);
+  if (succeeded(shardingOptionFromAttr))
+    return shardingOptionFromAttr;
+
+  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+  ShardingOption shardingOption;
+
+  if (failed(shardingOp.verifyShardingInterfaceImpl()))
+    return op->emitOpError() << "invalid sharding interface implementation";
+  SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+  unsigned numOperands = op->getNumOperands();
+  shardingOption.shardingArray.resize(loopTypes.size());
+  // Partial axes and type of the (at most one) result annotated as partial.
+  // `partialType` is only meaningful when `partialMeshAxes` is non-empty; it
+  // is initialized so that it can never be read uninitialized.
+  llvm::SmallVector<int32_t> partialMeshAxes;
+  Partial partialType = Partial::Sum;
+  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
+  bool anyShardingInResultsOrOperands = false;
+
+  // 2. Fill sharding option based on op results
+  for (OpResult result : op->getResults()) {
+    AffineMap map = maps[numOperands + result.getResultNumber()];
+    FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr(result, true);
+    if (failed(shardAttr))
+      continue;
+    anyShardingInResultsOrOperands = true;
+    // Handle the split axes: calculate the corresponding loop index for each
+    // split axes sub-array, and then store the sub-array to
+    // shardingOption[index]
+    for (auto it : llvm::zip(map.getResults(), shardAttr->getSplitAxes())) {
+      AffineExpr expr = std::get<0>(it);
+      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      auto dim = expr.cast<AffineDimExpr>();
+      unsigned index = dim.getPosition();
+      visitedLoopIndices.insert(index);
+      if (failed(fillShardingOption(op, shardingOption, shardAttr->getCluster(),
+                                    axes, index)))
+        return failure();
+    }
+
+    // Handle the partial axes: at this stage, the exact loop index/indices
+    // cannot be decided because there could be multiple reduction loops.
+    ArrayRef<int32_t> partialAxes = shardAttr->getPartialAxes();
+    if (!partialAxes.empty()) {
+      if (!partialMeshAxes.empty())
+        return op->emitOpError() << "at most one result with partial axes is "
+                                    "supported at present";
+      partialType = shardAttr->getPartialType();
+      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
+      // Add all the reduction loop indices to `visitedLoopIndices` if
+      // `partialAxes` is not empty
+      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
+        if (isReductionLoop(loopTypes[loopIdx]))
+          visitedLoopIndices.insert(loopIdx);
+      }
+    }
+  }
+
+  // 3. Fill sharding option based on operands
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+        getMeshShardingAttr(opOperand);
+    if (failed(maybeShardAttr))
+      continue;
+
+    anyShardingInResultsOrOperands = true;
+    bool annotateForUsers = maybeShardAttr->first;
+    MeshShardingAttr shardAttr = maybeShardAttr->second;
+    AffineMap map = maps[opOperand.getOperandNumber()];
+    unsigned numDims = map.getNumDims();
+
+    // Handle the split axes. Partial axes don't need to be handled because
+    // they only affect the defining op of the operand.
+    //
+    // TODO: Change to process the operands with single loop index first and
+    // then the operands with multiple loop indices
+    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
+      AffineExpr expr = std::get<0>(it);
+      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+          checkOperandAffineExpr(expr, numDims);
+      if (failed(loopIndices))
+        return op->emitOpError()
+               << "operand's affine expression is restricted to const_i * "
+                  "dim_i + const_j * dim_j + ...";
+      if (loopIndices->empty())
+        continue;
+      if (loopIndices->size() == 1) {
+        unsigned loopIdx = *loopIndices->begin();
+        visitedLoopIndices.insert(loopIdx);
+        // An annotation that is not meant for this op's use of the value
+        // (annotate_for_users == false) only acts as a hint: conflicts with
+        // it are ignored.
+        if (failed(fillShardingOption(op, shardingOption,
+                                      shardAttr.getCluster(), axes, loopIdx,
+                                      !annotateForUsers)))
+          return failure();
+        continue;
+      }
+      // If multiple loop indices correspond to a dimension of an operand, it is
+      // difficult to infer which loop indices are responsible for sharding.
+      // Therefore, the exact loop index must be specified by others.
+      bool seenLoopIndices =
+          llvm::any_of(*loopIndices, [&](unsigned loopIdx) {
+            return visitedLoopIndices.contains(loopIdx);
+          });
+      if (!seenLoopIndices)
+        return op->emitOpError()
+               << "the operand " << opOperand.getOperandNumber()
+               << " has multiple loop indices in a dimension, but none of "
+                  "them could be found in the exactly specified annotation "
+                  "of op results or operands.";
+    }
+  }
+
+  // 4. Finalize sharding option: if a result is partial but no reduction loop
+  // has been sharded yet, shard the first reduction loop whose kind matches
+  // the result's partial type along the partial mesh axes.
+  if (!partialMeshAxes.empty()) {
+    bool anyNonEmptyReductionLoop = llvm::any_of(
+        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
+          SmallVector<int32_t> &subArray = it.value();
+          int64_t idx = it.index();
+          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
+        });
+    if (!anyNonEmptyReductionLoop) {
+      bool filled = false;
+      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
+        if (isReductionLoop(loopTypes[idx]) &&
+            areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+          // Propagate conflicts (e.g. a partial mesh axis already used to
+          // split another loop) instead of emitting a diagnostic and then
+          // reporting success.
+          if (failed(fillShardingOption(op, shardingOption, nullptr,
+                                        partialMeshAxes, idx)))
+            return failure();
+          filled = true;
+          break;
+        }
+      }
+      if (!filled)
+        return op->emitOpError() << "no matched reduction loop found for the "
+                                    "result's partial type";
+    }
+  }
+  removeTrailingEmptySubArray(shardingOption.shardingArray);
+  if (!anyShardingInResultsOrOperands)
+    shardingOption.empty = true;
+  return shardingOption;
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultAddShardingAnnotations
+//===----------------------------------------------------------------------===//
+
+// Add a `mesh.shard` op for the given result, based on the details provided
+// in `shardingOption`, `map`, and `loopTypes`. Nothing is done if
+// `getMeshShardingAttr(result, false)` already finds an annotation. The new
+// op is inserted right after the value and replaces all other uses of it.
+static LogicalResult addShardOp(OpBuilder &b, OpResult result,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<IteratorType> loopTypes) {
+  if (succeeded(getMeshShardingAttr(result, false)))
+    return success();
+
+  auto resultType = llvm::cast<RankedTensorType>(result.getType());
+  SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
+  SmallVector<int32_t> partialAxes;
+
+  // process the split axes: result dimension i is sharded like the loop
+  // iterator that `map` assigns to it.
+  for (auto it : llvm::enumerate(map.getResults())) {
+    AffineExpr expr = it.value();
+    auto dim = expr.cast<AffineDimExpr>();
+    unsigned loopIdx = dim.getPosition();
+    if (loopIdx < shardingOption.shardingArray.size())
+      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
+  }
+
+  // process the partial axes: mesh axes sharding a reduction loop leave the
+  // result partially reduced. `partialType` must be initialized because it is
+  // passed to MeshShardingAttr::get below even when the op has no reduction
+  // loop at all (reading an uninitialized enum is undefined behavior).
+  Partial partialType = Partial::Sum;
+  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
+    IteratorType iType = std::get<0>(it);
+    if (isReductionLoop(iType)) {
+      Partial curPartialType = getPartialTypeFromReduction(iType);
+      if (!partialAxes.empty())
+        assert(partialType == curPartialType &&
+               "Only one reduction type is supported");
+      partialType = curPartialType;
+      const SmallVector<int32_t> &axis = std::get<1>(it);
+      partialAxes.append(axis);
+    }
+  }
+
+  removeTrailingEmptySubArray(splitAxes);
+  MeshShardingAttr shardAttr =
+      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
+                            partialAxes, partialType);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfterValue(result);
+  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
+                                   shardAttr, /*annotate_for_users=*/false);
+  result.replaceAllUsesExcept(shardOp, shardOp);
+  return success();
+}
+
+// Add a `mesh.shard` op (annotated for users) right before the owner of
+// `opOperand`, based on the details provided in `shardingOption`, `map`, and
+// `loopTypes`, and make the operand use it. Nothing is done if the operand
+// already has an annotation for users. Fails if the operand's affine
+// expression is unsupported, or if one tensor dimension would be sharded by
+// more than one loop iterator.
+static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<IteratorType> loopTypes) {
+  auto maybeShardingAttr = getMeshShardingAttr(opOperand);
+  if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
+    return success();
+  Value operand = opOperand.get();
+  auto operandType = llvm::cast<RankedTensorType>(operand.getType());
+  SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+  unsigned numDims = map.getNumDims();
+  for (auto it : llvm::enumerate(map.getResults())) {
+    int64_t idx = it.index();
+    AffineExpr expr = it.value();
+    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+        checkOperandAffineExpr(expr, numDims);
+    if (failed(loopIndices))
+      return failure();
+    SmallVector<unsigned> shardedLoopIndices;
+    for (unsigned loopIdx : *loopIndices) {
+      if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
+          !shardingOption.shardingArray[loopIdx].empty())
+        shardedLoopIndices.push_back(loopIdx);
+    }
+    // at most one sharded loop index is accepted per tensor dimension
+    if (shardedLoopIndices.size() > 1)
+      return failure();
+    if (shardedLoopIndices.size() == 1) {
+      splitAxes[idx].append(
+          shardingOption.shardingArray[shardedLoopIndices[0]]);
+    }
+  }
+
+  removeTrailingEmptySubArray(splitAxes);
+  MeshShardingAttr shardAttr =
+      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(opOperand.getOwner());
+  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
+                                   shardAttr, /*annotate_for_users=*/true);
+  opOperand.set(shardOp);
+
+  return success();
+}
+
+// Default implementation of ShardingInterface::addShardingAnnotations: adds a
+// `mesh.shard` op for every result of `op` and then for every operand,
+// following `shardingOption`. Fails as soon as one of them cannot be added.
+LogicalResult mesh::detail::defaultAddShardingAnnotations(
+    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
+  auto shardingOp = llvm::cast<ShardingInterface>(op);
+  SmallVector<IteratorType> iteratorTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<AffineMap> indexingMaps = shardingOp.getIndexingMaps();
+  // Indexing maps are laid out as [operand maps..., result maps...].
+  const unsigned firstResultMap = op->getNumOperands();
+
+  // 1. add mesh.shard ops for all op results
+  for (OpResult result : op->getResults()) {
+    AffineMap resultMap =
+        indexingMaps[firstResultMap + result.getResultNumber()];
+    if (failed(addShardOp(b, result, shardingOption, resultMap,
+                          iteratorTypes)))
+      return failure();
+  }
+
+  // 2. add mesh.shard ops for all operands
+  for (OpOperand &operand : op->getOpOperands()) {
+    AffineMap operandMap = indexingMaps[operand.getOperandNumber()];
+    if (failed(addShardOp(b, operand, shardingOption, operandMap,
+                          iteratorTypes)))
+      return failure();
+  }
+
+  return success();
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..9f85d8e9cb22d5b
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRMeshTransforms
+  ShardingPropagation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRMeshPassIncGen
+  MLIRShardingInterface
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMeshDialect
+  MLIRShardingInterface
+  MLIRSupport
+  MLIRTosaShardingInterfaceImpl
+)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
new file mode 100644
index 000000000000000..a3f305444afd835
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -0,0 +1,155 @@
+//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+#include <vector>
+
+namespace mlir {
+namespace mesh {
+#define GEN_PASS_DEF_SHARDINGPROPAGATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+} // namespace mesh
+} // namespace mlir
+
+#define DEBUG_TYPE "sharding-propagation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+// Returns a snapshot of the operations in `block`, in program order. Ops
+// inserted into the block after this call are not part of the snapshot.
+static std::vector<Operation *> getOperationsVector(Block &block) {
+  std::vector<Operation *> ops;
+  for (Operation &op : block)
+    ops.push_back(&op);
+  return ops;
+}
+
+// Returns a snapshot of the operations in `block`, last op first. Ops
+// inserted into the block after this call are not part of the snapshot.
+static std::vector<Operation *> getReversedOperationsVector(Block &block) {
+  std::vector<Operation *> ops;
+  for (Operation &op : llvm::reverse(block))
+    ops.push_back(&op);
+  return ops;
+}
+
+// For each operation that implements the ShardingInterface, infer the sharding
+// option of the operation from its operands and/or results using the
+// `getShardingOption` method. If the inferred sharding option is not empty,
+// record it on the op as the mesh-cluster and sharding-array attributes, and
+// add a `mesh.shard` operation for all remaining operands and results that do
+// not have sharding annotations. Terminators and `mesh.shard` ops are skipped;
+// any other op that lacks the interface is an error.
+LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
+    return success();
+
+  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
+  if (!shardingOp)
+    return op->emitOpError() << "sharding interface is not implemented.";
+
+  FailureOr<ShardingOption> shardingOption =
+      shardingOp.getShardingOption(builder);
+  if (failed(shardingOption))
+    return op->emitOpError() << "fail to get sharding option from results.";
+  // sharding info is empty, return immediately
+  if (shardingOption->empty)
+    return success();
+
+  // Build the sharding array attribute once; it is both logged and attached.
+  ArrayAttr shardingArrayAttr =
+      builder.getArrayOfI32ArrayAttr(shardingOption->shardingArray);
+  LLVM_DEBUG(DBGS() << "mesh cluster: " << shardingOption->cluster << "\n");
+  LLVM_DEBUG(DBGS() << "sharding array: " << shardingArrayAttr << "\n");
+  op->setAttr(getMeshClusterName(), shardingOption->cluster);
+  op->setAttr(getShardingArrayName(), shardingArrayAttr);
+
+  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption)))
+    return op->emitOpError() << "fail to set sharding annotations.";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagationPass
+//===----------------------------------------------------------------------===//
+// Propagates sharding annotations through the single block of a `func.func`.
+// The block is visited twice: first from the last op to the first, so that
+// annotations on results flow to producers, then in program order, so that
+// annotations flow to users.
+struct ShardingPropagationPass
+    : public mesh::impl::ShardingPropagationBase<ShardingPropagationPass> {
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    MLIRContext *ctx = funcOp.getContext();
+    Region &region = funcOp.getBody();
+    OpBuilder builder(ctx);
+    if (!region.hasOneBlock()) {
+      funcOp.emitOpError() << "only one block is supported!";
+      // Bail out: `region.front()` below is invalid when the region is empty
+      // (e.g. for a function declaration).
+      return signalPassFailure();
+    }
+    Block &block = region.front();
+
+    // Dump in program order (deterministic output).
+    LLVM_DEBUG({
+      DBGS() << "print all the ops' iterator types and indexing maps in the "
+                "block.\n";
+      for (Operation &op : block.getOperations())
+        if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
+          shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
+    });
+
+    // 1. propagate in reversed order
+    {
+      std::vector<Operation *> curOps = getReversedOperationsVector(block);
+      for (Operation *op : curOps) {
+        if (failed(visitOp(op, builder)))
+          return signalPassFailure();
+      }
+    }
+
+    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+                      << funcOp << "\n");
+
+    // 2. propagate in original order
+    {
+      std::vector<Operation *> curOps = getOperationsVector(block);
+      for (Operation *op : curOps) {
+        if (failed(visitOp(op, builder)))
+          return signalPassFailure();
+      }
+    }
+  }
+};
+
+} // namespace
+
+// Creates the sharding propagation pass. It runs on `func.func` ops whose body
+// is a single block, visits the block's ops last-to-first and then
+// first-to-last, and attaches sharding attributes and `mesh.shard` ops to ops
+// implementing ShardingInterface.
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::mesh::createShardingPropagationPass() {
+  return std::make_unique<ShardingPropagationPass>();
+}
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index 8e32579e0e4c2e3..ba5343dcd7ac6c1 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -26,4 +26,18 @@ add_mlir_dialect_library(MLIRTosaDialect
   MLIRViewLikeInterface
   )
 
+add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
+  IR/ShardingInterfaceImpl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMeshDialect
+  MLIRShardingInterface
+  MLIRSupport
+  MLIRTosaDialect
+  )
+
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..dace86533c0e231
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -0,0 +1,111 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tosa-sharding-impl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::tosa;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding model shared by all element-wise ops: one parallel loop per
+// dimension of operand 0, and the identity indexing map (over that rank) for
+// every operand and result. Both queries return an empty list when operand 0
+// is not a ranked tensor.
+template <typename ElemwiseOp>
+struct ElemwiseSharding
+    : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
+                                              ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    Type firstOperandType = op->getOperand(0).getType();
+    auto rankedType = llvm::dyn_cast<RankedTensorType>(firstOperandType);
+    if (!rankedType)
+      return {};
+    return SmallVector<IteratorType>(rankedType.getRank(),
+                                     IteratorType::Parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    Type firstOperandType = op->getOperand(0).getType();
+    auto rankedType = llvm::dyn_cast<RankedTensorType>(firstOperandType);
+    if (!rankedType)
+      return {};
+    AffineMap identity = AffineMap::getMultiDimIdentityMap(
+        rankedType.getRank(), op->getContext());
+    int64_t numMaps = op->getNumOperands() + op->getNumResults();
+    return SmallVector<AffineMap>(numMaps, identity);
+  }
+};
+
+// loop types: [parallel, parallel, parallel, reduction_sum]
+// indexing maps:
+// (d0, d1, d2, d3) -> (d0, d1, d3)
+// (d0, d1, d2, d3) -> (d0, d3, d2)
+// (d0, d1, d2, d3) -> (d0, d1, d2)
+// d0 indexes both operands and the result; d3 is the summed loop that only
+// the operands index.
+struct MatMulOpSharding
+    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto resultType =
+        llvm::dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!resultType)
+      return {};
+
+    // One parallel loop per result dimension, plus a trailing sum reduction.
+    SmallVector<IteratorType> iteratorTypes(resultType.getRank(),
+                                            IteratorType::Parallel);
+    iteratorTypes.push_back(IteratorType::ReductionSum);
+    return iteratorTypes;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    auto resultType =
+        llvm::dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!resultType)
+      return {};
+    MLIRContext *ctx = op->getContext();
+    const unsigned numLoops = 4;
+    return {
+        AffineMap::getMultiDimMapWithTargets(numLoops, {0, 1, 3}, ctx),
+        AffineMap::getMultiDimMapWithTargets(numLoops, {0, 3, 2}, ctx),
+        AffineMap::getMultiDimMapWithTargets(numLoops, {0, 1, 2}, ctx),
+    };
+  }
+};
+
+// Attach the element-wise ShardingInterface external model to `OpType` in
+// `ctx`.
+template <typename OpType>
+static void registerElemwiseOne(MLIRContext *ctx) {
+  OpType::template attachInterface<ElemwiseSharding<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+/// Calls registerElemwiseOne for every type in `OpTypes` (C++17 fold
+/// expression over the comma operator).
+template <typename... OpTypes>
+static void registerElemwiseAll(MLIRContext *ctx) {
+  (registerElemwiseOne<OpTypes>(ctx), ...);
+}
+
+} // namespace
+
+// Register ShardingInterface external models for TOSA ops: the element-wise
+// model for each op listed below and the matmul model for MatMulOp. The models
+// are attached through a DialectRegistry extension keyed on TosaDialect, so
+// they are installed in a context when that dialect is loaded there.
+void mlir::tosa::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
+    registerElemwiseAll<
+        ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
+        BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
+        LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
+        MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
+        LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+        GreaterOp, GreaterEqualOp>(ctx);
+
+    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index f4e29539214b4b6..b247b8cc694eca9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -259,6 +259,19 @@ SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }
 
+/// Converts an ArrayAttr whose elements are ArrayAttrs of IntegerAttrs into a
+/// nested vector, truncating each integer to int32_t. Elements of any other
+/// kind trigger the `llvm::cast` assertion.
+SmallVector<SmallVector<int32_t>>
+mlir::getArrayOfI32Array(ArrayAttr arrayAttr) {
+  SmallVector<SmallVector<int32_t>> rows;
+  rows.reserve(arrayAttr.size());
+  for (Attribute rowAttr : arrayAttr) {
+    SmallVector<int32_t> &row = rows.emplace_back();
+    for (Attribute elemAttr : llvm::cast<ArrayAttr>(rowAttr))
+      row.push_back(
+          static_cast<int32_t>(llvm::cast<IntegerAttr>(elemAttr).getInt()));
+  }
+  return rows;
+}
+
 // TODO: do we have any common utily for this?
 static MLIRContext *getContext(OpFoldResult val) {
   assert(val && "Invalid value");
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9cdac964710ca86..db5087ce42809f8 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -235,6 +235,19 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
   return permutationMap;
 }
 
+/// Returns the map `(d0, ..., d{numDims-1}) -> (d{targets[0]}, ...)` with no
+/// symbols. For example, numDims = 4 and targets = {0, 3, 2} yields
+/// `(d0, d1, d2, d3) -> (d0, d3, d2)`. Every target must be in
+/// [0, numDims).
+AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
+                                               ArrayRef<int64_t> targets,
+                                               MLIRContext *context) {
+  // Collect all result expressions first and unique the map once. Calling
+  // `insertResult` in a loop would build and unique one intermediate map per
+  // target and copy the growing result list each time.
+  SmallVector<AffineExpr, 4> results;
+  results.reserve(targets.size());
+  for (int64_t target : targets) {
+    assert(target >= 0 && target < static_cast<int64_t>(numDims) &&
+           "target dimension out of range");
+    results.push_back(getAffineDimExpr(static_cast<unsigned>(target), context));
+  }
+  return AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0, results,
+                        context);
+}
+
 template <typename AffineExprContainer>
 static SmallVector<AffineMap, 4>
 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index ab20f4863e11c23..b86ee432a53af04 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -316,6 +316,15 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
   return getArrayAttr(attrs);
 }
 
+/// Builds an ArrayAttr with one i32 ArrayAttr per entry of `values`.
+ArrayAttr
+Builder::getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values) {
+  SmallVector<Attribute, 8> subArrayAttrs;
+  subArrayAttrs.reserve(values.size());
+  for (ArrayRef<int32_t> subArray : values)
+    subArrayAttrs.push_back(getI32ArrayAttr(subArray));
+  return getArrayAttr(subArrayAttrs);
+}
+
 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
   auto attrs = llvm::map_to_vector<8>(
       values, [](Type v) -> Attribute { return TypeAttr::get(v); });
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
new file mode 100644
index 000000000000000..4c0809dc5e58636
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1)
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_3d(rank = 3)
+
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT: tosa.sigmoid
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: return
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V2]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V2]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.sigmoid %[[V1]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.tanh %[[V1]]
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
+  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V8:.*]] = tosa.negate %[[V7]]
+  // CHECK-NEXT:  %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[V6]], %[[V9]]
+  return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %2 : tensor<2x16x32xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(a)
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [0 : i32]]}
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_1d, sharding_array =  {{\[\[}}], [], [0 : i32]]}
+  %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array =  {{\[\[}}], [], [], [0 : i32]]}
+  %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
+  %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  return %5 : tensor<2x4x8xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [1 : i32, 2 : i32], [0 : i32]]}
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
+  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_3d, sharding_array =  {{\[\[}}], [], [1 : i32, 2 : i32]]}
+  %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array =  {{\[\[}}], [], [0 : i32], [1 : i32, 2 : i32]]}
+  %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
+  %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+  return %6 : tensor<2x4x8xf32>
+}

>From bd81c39b3349133fff9a81d068f3058b4360e284 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Fri, 20 Oct 2023 00:57:03 +0000
Subject: [PATCH 2/9] format code

---
 mlir/include/mlir/IR/AffineMap.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 18e2313ef2b446b..919ca35886c6c8b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -110,8 +110,8 @@ class AffineMap {
   /// * getMultiDimMapWithTargets(3, [2, 1])
   ///       -> affine_map<(d0, d1, d2) -> (d2, d1)>
   static AffineMap getMultiDimMapWithTargets(unsigned numDims,
-                                      ArrayRef<int64_t> targets,
-                                      MLIRContext *context);
+                                             ArrayRef<int64_t> targets,
+                                             MLIRContext *context);
 
   /// Returns a vector of AffineMaps; each with as many results as
   /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many

>From bf1dfddbc1498db6cf3ada7ecad932b9ce162816 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Fri, 20 Oct 2023 18:19:14 +0000
Subject: [PATCH 3/9] fix comments, 1st

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  6 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  8 +--
 .../Mesh/Interfaces/ShardingInterface.h       | 10 +--
 .../Mesh/Interfaces/ShardingInterface.td      | 25 ++++++--
 .../Mesh/Interfaces/ShardingInterface.cpp     | 63 +++++++++++++------
 .../Mesh/Transforms/ShardingPropagation.cpp   | 19 +-----
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      | 16 +++--
 mlir/lib/IR/Builders.cpp                      |  2 +-
 .../Dialect/Mesh/sharding-propagation.mlir    | 12 ++--
 9 files changed, 95 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index b6623ed818f0770..a91ef569347bff1 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -143,10 +143,10 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
                      "ArrayRef<SmallVector<int32_t>>":$split_axes,
                      "ArrayRef<int32_t>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
-      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
-        llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
+                  split_axes, [&](ArrayRef<int32_t> array) {
           return DenseI32ArrayAttr::get($_ctxt, array);
-      }));
+      });
       return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index cb86887091330c8..05eba66a89949b6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -33,12 +33,8 @@ bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
 
 template <typename T>
 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
-  for (int64_t i = array.size() - 1; i >= 0; i--) {
-    if (array[i].empty())
-      array.pop_back();
-    else
-      break;
-  }
+  while (!array.empty() && array.back().empty())
+    array.pop_back();
 }
 
 Partial getPartialTypeFromReduction(IteratorType iType);
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 1d19e41ac1fc555..829f33aae440744 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -30,17 +30,13 @@ struct ShardingOption {
   // present. Note that it is different from that an operation is not sharded.
   bool empty = false;
   ShardingOption() = default;
-  ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
-      : shardingArray(shardingArray), cluster(cluster) {}
+  ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+      : shardingArray(std::move(shardingArray)), cluster(cluster) {}
 };
 
-constexpr StringRef getShardingArrayName() { return "sharding_array"; }
-
-constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
-
 namespace detail {
 
-FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op);
 
 LogicalResult
 defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index c98b9f081492997..0da310c4442b4c5 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -22,6 +22,17 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Returns a list of iterator types that describe the number of loops.
+          The iterator types determine how the operation tranverses its input
+          and output tensors.
+
+          Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
+          types are paralle, parallel, reduction-sum. This indicates that M and
+          N are traversed in parallel, while the K dimension is used for
+          reduction.
+
+          Example 2: A softmax op's loop iterator types are parallel and
+          invalid. The second dimension is considered as invalid because it is
+          neigher parallel nor any kind of reduction. 
         }],
         /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
         /*methodName=*/"getLoopIteratorTypes",
@@ -32,6 +43,9 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Return the indexing maps attribute within the current operation.
+          Indexing maps determine how indices in the iteration space map to
+          tensor indices. They are specified using `affine_map` in MLIR, which
+          provides an affine transformation of indices.
         }],
         /*retTy=*/"SmallVector<AffineMap>",
         /*methodName=*/"getIndexingMaps",
@@ -47,13 +61,11 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         }],
         /*retTy=*/"FailureOr<ShardingOption>",
         /*methodName=*/"getShardingOption",
-        /*args=*/(ins
-          "OpBuilder &":$b
-        ),
+        /*args=*/(ins),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return detail::defaultGetShardingOption(
-            $_op.getOperation(), b);
+            $_op.getOperation());
         }]
       >,
       InterfaceMethod<
@@ -80,6 +92,11 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       LogicalResult verifyShardingInterfaceImpl();
 
       void printLoopTypesAndIndexingMaps(raw_ostream &os);
+
+      FailureOr<ShardingOption> getShardingOptionFromAttr();
+
+      void setShardingOptionAttr(Builder &b, const ShardingOption& option);
+
     }];
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 7d5c73851bb1f4e..2fcf418032317db 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -29,16 +29,6 @@ using namespace mlir::mesh;
 // common util functions
 //===----------------------------------------------------------------------===//
 
-static FailureOr<ShardingOption> getShardingOptionFromAttr(Operation *op) {
-  auto arrayAttr = op->getAttrOfType<ArrayAttr>(getShardingArrayName());
-  if (!arrayAttr)
-    return failure();
-  auto symbolRefAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName());
-  if (!symbolRefAttr)
-    return failure();
-  return ShardingOption(getArrayOfI32Array(arrayAttr), symbolRefAttr);
-}
-
 // This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
 // for a given operation result.
 static FailureOr<MeshShardingAttr>
@@ -97,10 +87,6 @@ getMeshShardingAttr(OpOperand &opOperand) {
   return failure();
 }
 
-//===----------------------------------------------------------------------===//
-// ShardingInterface::verifyShardingInterfaceImpl
-//===----------------------------------------------------------------------===//
-
 static LogicalResult
 checkOperandAffineExprRecursively(AffineExpr expr,
                                   SmallVectorImpl<bool> &seenIds) {
@@ -161,6 +147,10 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
   return positions;
 }
 
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
   Operation *op = getOperation();
 
@@ -218,6 +208,43 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
   os << "\n";
 }
 
+//===----------------------------------------------------------------------===//
+// ShardingInterface::getShardingOptionFromAttr
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+} // namespace
+
+FailureOr<ShardingOption> mesh::ShardingInterface::getShardingOptionFromAttr() {
+  Operation *op = getOperation();
+  auto arrayAttr = op->getAttrOfType<ArrayAttr>(getShardingArrayName());
+  if (!arrayAttr)
+    return failure();
+  auto symbolRefAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName());
+  if (!symbolRefAttr)
+    return failure();
+  return ShardingOption(getArrayOfI32Array(arrayAttr), symbolRefAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::setShardingOptionAttr
+//===----------------------------------------------------------------------===//
+
+void mesh::ShardingInterface::setShardingOptionAttr(
+    Builder &b, const ShardingOption &option) {
+  if (option.empty)
+    return;
+  Operation *op = getOperation();
+  ArrayAttr shardingArrayAttr = b.getArrayOfI32ArrayAttr(option.shardingArray);
+  op->setDiscardableAttr(getMeshClusterName(), option.cluster);
+  op->setDiscardableAttr(getShardingArrayName(), shardingArrayAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // detail::defaultGetShardingOption
 //===----------------------------------------------------------------------===//
@@ -264,16 +291,16 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
 
 } // namespace
 
-FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(Operation *op,
-                                                                 OpBuilder &b) {
+FailureOr<ShardingOption>
+mesh::detail::defaultGetShardingOption(Operation *op) {
 
   // 1. If a valid sharding attribute exists, use it.
+  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   FailureOr<ShardingOption> shardingOptionFromAttr =
-      getShardingOptionFromAttr(op);
+      shardingOp.getShardingOptionFromAttr();
   if (succeeded(shardingOptionFromAttr))
     return shardingOptionFromAttr;
 
-  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   ShardingOption shardingOption;
 
   if (failed(shardingOp.verifyShardingInterfaceImpl()))
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index a3f305444afd835..76aba79250edbbf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -67,8 +67,7 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
     return failure();
   }
 
-  FailureOr<ShardingOption> shardingOption =
-      shardingOp.getShardingOption(builder);
+  FailureOr<ShardingOption> shardingOption = shardingOp.getShardingOption();
   if (failed(shardingOption)) {
     op->emitOpError() << "fail to get sharding option from results.";
     return failure();
@@ -76,14 +75,7 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   // sharding info is empty, return immediately
   if (shardingOption->empty)
     return success();
-
-  ArrayAttr shardingArrayAttr =
-      builder.getArrayOfI32ArrayAttr(shardingOption->shardingArray);
-  LLVM_DEBUG(DBGS() << "mesh cluster: " << shardingOption->cluster << "\n");
-  LLVM_DEBUG(DBGS() << "sharding array: " << shardingArrayAttr << "\n");
-  op->setAttr(getMeshClusterName(), shardingOption->cluster);
-  op->setAttr(getShardingArrayName(),
-              builder.getArrayOfI32ArrayAttr(shardingOption->shardingArray));
+  shardingOp.setShardingOptionAttr(builder, *shardingOption);
 
   if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
     op->emitOpError() << "fail to set sharding annotations.";
@@ -112,13 +104,8 @@ struct ShardingPropagationPass
     LLVM_DEBUG(
       DBGS() << "print all the ops' iterator types and indexing maps in the "
                 "block.\n";
-      DenseSet<ShardingInterface> ops;
       for (Operation &op : block.getOperations()) {
-        if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) {
-          ops.insert(shardingOp);
-        }
-      }
-      for (ShardingInterface shardingOp : ops) {
+        if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
         shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
       }
     );
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index b247b8cc694eca9..e7f0778ec214f38 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -263,12 +263,18 @@ SmallVector<SmallVector<int32_t>>
 mlir::getArrayOfI32Array(ArrayAttr arrayAttr) {
   SmallVector<SmallVector<int32_t>> arrayOfI32Array;
   for (auto attr : arrayAttr) {
-    arrayOfI32Array.push_back(llvm::to_vector(
-        llvm::map_range(llvm::cast<ArrayAttr>(attr), [&](Attribute intAttr) {
-          return static_cast<int32_t>(
-              llvm::cast<IntegerAttr>(intAttr).getInt());
-        })));
+    if (auto denseI32ArrayAttr = attr.dyn_cast<DenseI32ArrayAttr>()) {
+      arrayOfI32Array.push_back(
+          llvm::to_vector(denseI32ArrayAttr.asArrayRef()));
+    } else {
+      arrayOfI32Array.push_back(llvm::map_to_vector(
+          llvm::cast<ArrayAttr>(attr), [](Attribute intAttr) {
+            return static_cast<int32_t>(
+                llvm::cast<IntegerAttr>(intAttr).getInt());
+          }));
+    }
   }
+
   return arrayOfI32Array;
 }
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index b86ee432a53af04..5a67d20be6fbaec 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -320,7 +320,7 @@ ArrayAttr
 Builder::getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values) {
   auto attrs =
       llvm::map_to_vector<8>(values, [this](ArrayRef<int32_t> v) -> Attribute {
-        return getI32ArrayAttr(v);
+        return getDenseI32ArrayAttr(v);
       });
   return getArrayAttr(attrs);
 }
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 4c0809dc5e58636..a81fb3a8db23cfa 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -139,11 +139,11 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
 func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
   %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [0 : i32]]}
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = [array<i32>, array<i32>, array<i32: 0>]}
   %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_1d, sharding_array =  {{\[\[}}], [], [0 : i32]]}
+  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_1d, sharding_array = [array<i32>, array<i32>, array<i32: 0>]}
   %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array =  {{\[\[}}], [], [], [0 : i32]]}
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = [array<i32>, array<i32>, array<i32>, array<i32: 0>]}
   %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
   %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
   %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
@@ -154,12 +154,12 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
 // CHECK-LABEL: func.func @mlp_2d_weight_stationary
 func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
   %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [1 : i32, 2 : i32], [0 : i32]]}
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = [array<i32>, array<i32>, array<i32: 1, 2>, array<i32: 0>]}
   %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
   %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
-  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_3d, sharding_array =  {{\[\[}}], [], [1 : i32, 2 : i32]]}
+  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_3d, sharding_array = [array<i32>, array<i32>, array<i32: 1, 2>]}
   %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array =  {{\[\[}}], [], [0 : i32], [1 : i32, 2 : i32]]}
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = [array<i32>, array<i32>, array<i32: 0>, array<i32: 1, 2>]}
   %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
   %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
   %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>

>From 95ed43c75a319822e294578ee28e68b5aa8e988d Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Sat, 21 Oct 2023 19:17:43 +0000
Subject: [PATCH 4/9] fix comments, 2nd

---
 .../Mesh/Interfaces/ShardingInterface.h       |  5 +-
 .../Mesh/Interfaces/ShardingInterface.td      |  8 +--
 .../mlir/Dialect/Mesh/Transforms/Passes.h     |  2 -
 .../mlir/Dialect/Mesh/Transforms/Passes.td    |  1 -
 .../Mesh/Interfaces/ShardingInterface.cpp     | 42 ++++++++--------
 .../Mesh/Transforms/ShardingPropagation.cpp   | 49 ++++---------------
 6 files changed, 38 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 829f33aae440744..e26d3a395280e7d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -26,8 +26,9 @@ struct ShardingOption {
   // mesh axes the i-th loop will be sharded on.
   ShardingArray shardingArray;
   SymbolRefAttr cluster;
-  // `empty` is true indicates that no sharding infomation can be inferred at
-  // present. Note that it is different from that an operation is not sharded.
+  // `empty` being true indicates that no sharding information can be inferred
+  // at present. Note that it is different from the case where an operation is
+  // not sharded.
   bool empty = false;
   ShardingOption() = default;
   ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 0da310c4442b4c5..dbab90a98538c93 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -22,17 +22,17 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Returns a list of iterator types that describe the number of loops.
-          The iterator types determine how the operation tranverses its input
-          and output tensors.
+          The iterator types determine how the operation traverses its input and
+          output tensors.
 
           Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
-          types are paralle, parallel, reduction-sum. This indicates that M and
+          types are parallel, parallel, reduction-sum. This indicates that M and
           N are traversed in parallel, while the K dimension is used for
           reduction.
 
           Example 2: A softmax op's loop iterator types are parallel and
           invalid. The second dimension is considered as invalid because it is
-          neigher parallel nor any kind of reduction. 
+          neither parallel nor any kind of reduction. 
         }],
         /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
         /*methodName=*/"getLoopIteratorTypes",
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index aa3555f7f186f24..83399d10beaae48 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -26,8 +26,6 @@ namespace mesh {
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
 
-std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index d36adfe476a72ac..c09cf3e710d4278 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,7 +24,6 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
-  let constructor = "mlir::mesh::createShardingPropagationPass()";
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 2fcf418032317db..e2fb1a95afb0241 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -42,9 +42,10 @@ getMeshShardingAttr(OpResult result, bool useOperandSharding) {
   });
 
   if (anyShardedForDef) {
-    assert(val.hasOneUse() &&
-           "expected to has exact one use if it has a use of mesh.shard "
-           "without unit attr annotate_for_users");
+    // expected to have exact one use if it has a use of `mesh.shard` without
+    // unit attr annotate_for_users
+    if (!val.hasOneUse())
+      return failure();
     auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
     return shardOp.getShard();
   } else if (useOperandSharding) {
@@ -80,9 +81,8 @@ getMeshShardingAttr(OpResult result, bool useOperandSharding) {
 static FailureOr<std::pair<bool, MeshShardingAttr>>
 getMeshShardingAttr(OpOperand &opOperand) {
   Value val = opOperand.get();
-  if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) {
+  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
     return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
-  }
 
   return failure();
 }
@@ -264,21 +264,23 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
       return success();
     else
       return op->emitOpError()
-             << "sharding option confilicts on loop iterator " << loopIdx;
+             << "sharding option conflicts on loop iterator " << loopIdx;
   }
   for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
-    if (i != loopIdx) {
-      for (int32_t axis : meshAxes)
-        if (std::find(shardingOption.shardingArray[i].begin(),
-                      shardingOption.shardingArray[i].end(),
-                      axis) != shardingOption.shardingArray[i].end()) {
-          if (ignoreIfConflicted)
-            return success();
-          else
-            return op->emitOpError()
-                   << "sharding option confilicts because of mesh axis " << axis
-                   << " duplicates";
-        }
+    if (i == loopIdx)
+      continue;
+
+    for (int32_t axis : meshAxes) {
+      if (std::find(shardingOption.shardingArray[i].begin(),
+                    shardingOption.shardingArray[i].end(),
+                    axis) != shardingOption.shardingArray[i].end()) {
+        if (ignoreIfConflicted)
+          return success();
+        else
+          return op->emitOpError()
+                 << "sharding option conflicts because mesh axes " << axis
+                 << " duplicate";
+      }
     }
   }
   if (cluster)
@@ -367,7 +369,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
     unsigned numDims = map.getNumDims();
 
     // Handle the split axes, and partial axes don't need to be handled because
-    // they only affect the definig op of the operand
+    // they only affect the defining op of the operand
     //
     // TODO: Change to process the operands with single loop index first and
     // then the operands with multiple loop indices
@@ -568,4 +570,4 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
   }
 
   return success();
-}
\ No newline at end of file
+}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 76aba79250edbbf..03bc3953741bccf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -34,24 +34,6 @@ namespace {
 // Utilities
 //===----------------------------------------------------------------------===//
 
-static std::vector<Operation *> getOperationsVector(Block &block) {
-  std::vector<Operation *> res;
-  for (auto it = block.begin(); it != block.end(); ++it) {
-    Operation *op = &*it;
-    res.push_back(op);
-  }
-  return res;
-}
-
-static std::vector<Operation *> getReversedOperationsVector(Block &block) {
-  std::vector<Operation *> res;
-  for (auto it = block.rbegin(); it != block.rend(); ++it) {
-    Operation *op = &*it;
-    res.push_back(op);
-  }
-  return res;
-}
-
 // For each operation that implements the ShardingInterface, infer the sharding
 // option of the operation from its operands and/or results using the
 // `getShardingOption` method. If the inferred sharding option is not empty, add
@@ -85,10 +67,10 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 }
 
 //===----------------------------------------------------------------------===//
-// ShardingPropagationPass
+// ShardingPropagation
 //===----------------------------------------------------------------------===//
-struct ShardingPropagationPass
-    : public mesh::impl::ShardingPropagationBase<ShardingPropagationPass> {
+struct ShardingPropagation
+    : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
   void runOnOperation() override {
     func::FuncOp funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -112,31 +94,18 @@ struct ShardingPropagationPass
     // clang-format on
 
     // 1. propagate in reversed order
-    {
-      std::vector<Operation *> curOps = getReversedOperationsVector(block);
-      for (Operation *op : curOps) {
-        if (failed(visitOp(op, builder)))
-          return signalPassFailure();
-      }
-    }
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
 
     LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
                       << funcOp << "\n");
 
     // 2. propagate in original order
-    {
-      std::vector<Operation *> curOps = getOperationsVector(block);
-      for (Operation *op : curOps) {
-        if (failed(visitOp(op, builder)))
-          return signalPassFailure();
-      }
-    }
+    for (Operation &op : llvm::make_early_inc_range(block))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
   }
 };
 
 } // namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::mesh::createShardingPropagationPass() {
-  return std::make_unique<ShardingPropagationPass>();
-}

>From f4290d5ea7999ff634c7a6904b7888152996df8e Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Sat, 21 Oct 2023 21:06:04 +0000
Subject: [PATCH 5/9] remove sharding option attr

---
 .../Mesh/Interfaces/ShardingInterface.td      |  5 --
 .../Mesh/Interfaces/ShardingInterface.cpp     | 50 ++-----------------
 .../Mesh/Transforms/ShardingPropagation.cpp   |  1 -
 .../Dialect/Mesh/sharding-propagation.mlir    | 33 +++++++++---
 4 files changed, 30 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index dbab90a98538c93..a76fcce258dabb5 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -92,11 +92,6 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       LogicalResult verifyShardingInterfaceImpl();
 
       void printLoopTypesAndIndexingMaps(raw_ostream &os);
-
-      FailureOr<ShardingOption> getShardingOptionFromAttr();
-
-      void setShardingOptionAttr(Builder &b, const ShardingOption& option);
-
     }];
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index e2fb1a95afb0241..45039f6cb10fe87 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -208,43 +208,6 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
   os << "\n";
 }
 
-//===----------------------------------------------------------------------===//
-// ShardingInterface::getShardingOptionFromAttr
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-constexpr StringRef getShardingArrayName() { return "sharding_array"; }
-
-constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
-
-} // namespace
-
-FailureOr<ShardingOption> mesh::ShardingInterface::getShardingOptionFromAttr() {
-  Operation *op = getOperation();
-  auto arrayAttr = op->getAttrOfType<ArrayAttr>(getShardingArrayName());
-  if (!arrayAttr)
-    return failure();
-  auto symbolRefAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName());
-  if (!symbolRefAttr)
-    return failure();
-  return ShardingOption(getArrayOfI32Array(arrayAttr), symbolRefAttr);
-}
-
-//===----------------------------------------------------------------------===//
-// ShardingInterface::setShardingOptionAttr
-//===----------------------------------------------------------------------===//
-
-void mesh::ShardingInterface::setShardingOptionAttr(
-    Builder &b, const ShardingOption &option) {
-  if (option.empty)
-    return;
-  Operation *op = getOperation();
-  ArrayAttr shardingArrayAttr = b.getArrayOfI32ArrayAttr(option.shardingArray);
-  op->setDiscardableAttr(getMeshClusterName(), option.cluster);
-  op->setDiscardableAttr(getShardingArrayName(), shardingArrayAttr);
-}
-
 //===----------------------------------------------------------------------===//
 // detail::defaultGetShardingOption
 //===----------------------------------------------------------------------===//
@@ -295,14 +258,7 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
 
 FailureOr<ShardingOption>
 mesh::detail::defaultGetShardingOption(Operation *op) {
-
-  // 1. If a valid sharding attribute exists, use it.
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
-  FailureOr<ShardingOption> shardingOptionFromAttr =
-      shardingOp.getShardingOptionFromAttr();
-  if (succeeded(shardingOptionFromAttr))
-    return shardingOptionFromAttr;
-
   ShardingOption shardingOption;
 
   if (failed(shardingOp.verifyShardingInterfaceImpl()))
@@ -316,7 +272,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
   bool anyShardingInResultsOrOperands = false;
 
-  // 2. Fill sharding option based on op results
+  // 1. Fill sharding option based on op results
   for (OpResult result : op->getResults()) {
     AffineMap map = maps[numOperands + result.getResultNumber()];
     FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr(result, true);
@@ -355,7 +311,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
     }
   }
 
-  // 3. Fill sharding option based on operands
+  // 2. Fill sharding option based on operands
   for (OpOperand &opOperand : op->getOpOperands()) {
     FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
         getMeshShardingAttr(opOperand);
@@ -413,7 +369,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
     }
   }
 
-  // 4. Finalize sharding option
+  // 3. Finalize sharding option
   if (!partialMeshAxes.empty()) {
     bool anyNonEmptyReductionLoop = llvm::any_of(
         llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 03bc3953741bccf..80ee9dcfbc9b16a 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -57,7 +57,6 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   // sharding info is empty, return immediately
   if (shardingOption->empty)
     return success();
-  shardingOp.setShardingOptionAttr(builder, *shardingOption);
 
   if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
     op->emitOpError() << "fail to set sharding annotations.";
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index a81fb3a8db23cfa..bda407b52bfd4f2 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -137,31 +137,52 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
 
 // https://arxiv.org/abs/2211.05102 Figure 2(a)
 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
 func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
   %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = [array<i32>, array<i32>, array<i32: 0>]}
+  // CHECK: %[[V0:.*]] = tosa.matmul
   %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_1d, sharding_array = [array<i32>, array<i32>, array<i32: 0>]}
+  // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
   %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = [array<i32>, array<i32>, array<i32>, array<i32: 0>]}
+  // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V6:.*]] = mesh.shard %[[ARG2]] to <@mesh_1d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x32x8xf32>
+  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
   %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_1d, {{\[\[}}], [], []], partial = sum[0]> : tensor<2x4x8xf32>
   %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
   %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: return %[[V9]]
   return %5 : tensor<2x4x8xf32>
 }
 
 // https://arxiv.org/abs/2211.05102 Figure 2(b)
 // CHECK-LABEL: func.func @mlp_2d_weight_stationary
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
 func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  // CHECK-DAG: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> : tensor<2x4x8xf32>
   %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = [array<i32>, array<i32>, array<i32: 1, 2>, array<i32: 0>]}
+  // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_3d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[ARG1]] to <@mesh_3d, {{\[\[}}], [0], [1, 2]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
   %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_3d,  {{\[\[}}], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
   %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
-  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_3d, sharding_array = [array<i32>, array<i32>, array<i32: 1, 2>]}
+  // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
   %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = [array<i32>, array<i32>, array<i32: 0>, array<i32: 1, 2>]}
+  // CHECK-DAG: %[[V7:.*]] = mesh.shard %[[V6]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[ARG2]] to <@mesh_3d, {{\[\[}}], [1, 2], [0]]> annotate_for_users : tensor<2x32x8xf32>
+  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
   %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V11:.*]] = mesh.shard %[[V10]] to <@mesh_3d, {{\[\[}}], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
   %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V12:.*]] = mesh.shard %[[V11]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
   %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: return %[[V12]]
   return %6 : tensor<2x4x8xf32>
 }

>From 50205fc60db885ebee8d622d71a1e0d1f9a8bcf9 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Sat, 21 Oct 2023 21:47:34 +0000
Subject: [PATCH 6/9] clang format

---
 .../Mesh/Transforms/ShardingPropagation.cpp      | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 80ee9dcfbc9b16a..26804d1e290b29f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -81,16 +81,14 @@ struct ShardingPropagation
     }
     Block &block = region.front();
 
-    // clang-format off
     LLVM_DEBUG(
-      DBGS() << "print all the ops' iterator types and indexing maps in the "
-                "block.\n";
-      for (Operation &op : block.getOperations()) {
-        if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
-        shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
-      }
-    );
-    // clang-format on
+        DBGS() << "print all the ops' iterator types and indexing maps in the "
+                  "block.\n";
+        for (Operation &op
+             : block.getOperations()) {
+          if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
+            shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
+        });
 
     // 1. propagate in reversed order
     for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))

>From 2d1c8c699a644ec75f425bb9e64eb1e1cc6a5ae0 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Mon, 23 Oct 2023 23:52:13 +0000
Subject: [PATCH 7/9] address review comments (round 3)

---
 .../Mesh/Interfaces/ShardingInterface.cpp      | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 45039f6cb10fe87..70d5dca4b5c8cc4 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -106,10 +106,9 @@ checkOperandAffineExprRecursively(AffineExpr expr,
     AffineExpr lhs = binOpExpr.getLHS();
     AffineExpr rhs = binOpExpr.getRHS();
     AffineExpr dimExpr;
-    if (lhs.getKind() == AffineExprKind::DimId) {
+    if (lhs.getKind() == AffineExprKind::DimId &&
+        rhs.getKind() == AffineExprKind::Constant) {
       dimExpr = lhs;
-      if (rhs.getKind() != AffineExprKind::Constant)
-        return failure();
     } else if (rhs.getKind() == AffineExprKind::DimId &&
                lhs.getKind() == AffineExprKind::Constant) {
       dimExpr = rhs;
@@ -275,7 +274,8 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
   // 1. Fill sharding option based on op results
   for (OpResult result : op->getResults()) {
     AffineMap map = maps[numOperands + result.getResultNumber()];
-    FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr(result, true);
+    FailureOr<MeshShardingAttr> shardAttr =
+        getMeshShardingAttr(result, /*useOperandSharding*/ true);
     if (failed(shardAttr))
       continue;
     anyShardingInResultsOrOperands = true;
@@ -324,11 +324,11 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
     AffineMap map = maps[opOperand.getOperandNumber()];
     unsigned numDims = map.getNumDims();
 
-    // Handle the split axes, and partial axes don't need to be handled because
-    // they only affect the defining op of the operand
+    // Handle the split axes. Partial axes don't need to be handled because they
+    // only affect the defining op of the operand.
     //
     // TODO: Change to process the operands with single loop index first and
-    // then the operands with multiple loop indices
+    // then the operands with multiple loop indices.
     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
       ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
@@ -411,7 +411,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 const ShardingOption &shardingOption,
                                 AffineMap map,
                                 ArrayRef<IteratorType> loopTypes) {
-  if (succeeded(getMeshShardingAttr(result, false)))
+  if (succeeded(getMeshShardingAttr(result, /*useOperandSharding*/ false)))
     return success();
 
   auto resultType = result.getType().cast<RankedTensorType>();
@@ -421,6 +421,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
   // process the split axes
   for (auto it : llvm::enumerate(map.getResults())) {
     AffineExpr expr = it.value();
+    // `expr` must be an `AffineDimExpr` because `map` is verified by
+    // isProjectedPermutation
     auto dim = expr.cast<AffineDimExpr>();
     unsigned loopIdx = dim.getPosition();
     if (loopIdx < shardingOption.shardingArray.size())

>From d254e84305fb3963d1de400b53de155442bf0235 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Wed, 25 Oct 2023 02:01:47 +0000
Subject: [PATCH 8/9] change getShardingOption signature

---
 .../Mesh/Interfaces/ShardingInterface.h       |  15 +-
 .../Mesh/Interfaces/ShardingInterface.td      |   7 +-
 .../Mesh/Interfaces/ShardingInterface.cpp     | 180 +++++++++---------
 .../Mesh/Transforms/ShardingPropagation.cpp   |  99 +++++++++-
 4 files changed, 199 insertions(+), 102 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index e26d3a395280e7d..5e79bd255ee5221 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -35,9 +35,22 @@ struct ShardingOption {
       : shardingArray(std::move(shardingArray)), cluster(cluster) {}
 };
 
+// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
+// for a given operation result.
+FailureOr<MeshShardingAttr> getMeshShardingAttr(OpResult result,
+                                                bool useOperandSharding);
+
+// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
+// for a given operation operand.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand);
+
 namespace detail {
 
-FailureOr<ShardingOption> defaultGetShardingOption(Operation *op);
+FailureOr<ShardingOption>
+defaultGetShardingOption(Operation *op,
+                         ArrayRef<MeshShardingAttr> operandShardings,
+                         ArrayRef<MeshShardingAttr> resultShardings);
 
 LogicalResult
 defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index a76fcce258dabb5..21b6c8d4f599a8d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -61,11 +61,14 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         }],
         /*retTy=*/"FailureOr<ShardingOption>",
         /*methodName=*/"getShardingOption",
-        /*args=*/(ins),
+        /*args=*/(ins
+          "ArrayRef<MeshShardingAttr>": $operandShardings,
+          "ArrayRef<MeshShardingAttr>": $resultShardings
+        ),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return detail::defaultGetShardingOption(
-            $_op.getOperation());
+            $_op.getOperation(), operandShardings, resultShardings);
         }]
       >,
       InterfaceMethod<
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 70d5dca4b5c8cc4..f38001212a43d71 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -29,64 +29,6 @@ using namespace mlir::mesh;
 // common util functions
 //===----------------------------------------------------------------------===//
 
-// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
-// for a given operation result.
-static FailureOr<MeshShardingAttr>
-getMeshShardingAttr(OpResult result, bool useOperandSharding) {
-  Value val = result.cast<Value>();
-  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
-    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
-    if (!shardOp)
-      return false;
-    return !shardOp.getAnnotateForUsers();
-  });
-
-  if (anyShardedForDef) {
-    // expected to have exact one use if it has a use of `mesh.shard` without
-    // unit attr annotate_for_users
-    if (!val.hasOneUse())
-      return failure();
-    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
-    return shardOp.getShard();
-  } else if (useOperandSharding) {
-    bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
-      auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
-      if (!shardOp)
-        return false;
-      return shardOp.getAnnotateForUsers();
-    });
-    if (anyShardedForUsers) {
-      SmallVector<ShardOp> shardOps;
-      for (Operation *user : val.getUsers()) {
-        ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
-        if (shardOp)
-          shardOps.push_back(shardOp);
-      }
-      MeshShardingAttr shardForDef = shardOps[0].getShard();
-      for (size_t i = 1; i < shardOps.size(); ++i) {
-        // TODO: Deduce a reasonable mesh sharding attr for def when they are
-        // different
-        assert(shardOps[i].getShard() == shardForDef &&
-               "only support all shard ops have the same mesh sharding attr");
-      }
-      return shardForDef;
-    }
-  }
-
-  return failure();
-}
-
-// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
-// for a given operation operand.
-static FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpOperand &opOperand) {
-  Value val = opOperand.get();
-  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
-    return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
-
-  return failure();
-}
-
 static LogicalResult
 checkOperandAffineExprRecursively(AffineExpr expr,
                                   SmallVectorImpl<bool> &seenIds) {
@@ -146,6 +88,64 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
   return positions;
 }
 
+//===----------------------------------------------------------------------===//
+// mesh::getMeshShardingAttr
+//===----------------------------------------------------------------------===//
+
+FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr(OpResult result,
+                                                      bool useOperandSharding) {
+  Value val = result.cast<Value>();
+  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      return false;
+    return !shardOp.getAnnotateForUsers();
+  });
+
+  if (anyShardedForDef) {
+    // expected to have exact one use if it has a use of `mesh.shard` without
+    // unit attr annotate_for_users
+    if (!val.hasOneUse())
+      return failure();
+    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
+    return shardOp.getShard();
+  } else if (useOperandSharding) {
+    bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
+      auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+      if (!shardOp)
+        return false;
+      return shardOp.getAnnotateForUsers();
+    });
+    if (anyShardedForUsers) {
+      SmallVector<ShardOp> shardOps;
+      for (Operation *user : val.getUsers()) {
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
+        if (shardOp)
+          shardOps.push_back(shardOp);
+      }
+      MeshShardingAttr shardForDef = shardOps[0].getShard();
+      for (size_t i = 1; i < shardOps.size(); ++i) {
+        // TODO: Deduce a reasonable mesh sharding attr for def when they are
+        // different
+        assert(shardOps[i].getShard() == shardForDef &&
+               "only support all shard ops have the same mesh sharding attr");
+      }
+      return shardForDef;
+    }
+  }
+
+  return failure();
+}
+
+FailureOr<std::pair<bool, MeshShardingAttr>>
+mesh::getMeshShardingAttr(OpOperand &opOperand) {
+  Value val = opOperand.get();
+  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
+    return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // ShardingInterface::verifyShardingInterfaceImpl
 //===----------------------------------------------------------------------===//
@@ -214,19 +214,18 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
 namespace {
 
 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
-static LogicalResult
-fillShardingOption(Operation *op, ShardingOption &shardingOption,
-                   SymbolRefAttr cluster, ArrayRef<int32_t> meshAxes,
-                   unsigned loopIdx, bool ignoreIfConflicted = false) {
+static LogicalResult fillShardingOption(Operation *op,
+                                        ShardingOption &shardingOption,
+                                        SymbolRefAttr cluster,
+                                        ArrayRef<int32_t> meshAxes,
+                                        unsigned loopIdx) {
   if ((shardingOption.cluster && cluster &&
        shardingOption.cluster != cluster) ||
       (!shardingOption.shardingArray[loopIdx].empty() &&
        shardingOption.shardingArray[loopIdx] != meshAxes)) {
-    if (ignoreIfConflicted)
-      return success();
-    else
-      return op->emitOpError()
-             << "sharding option conflicts on loop iterator " << loopIdx;
+    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
+                      << loopIdx << "\n");
+    return failure();
   }
   for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
     if (i == loopIdx)
@@ -236,12 +235,9 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
       if (std::find(shardingOption.shardingArray[i].begin(),
                     shardingOption.shardingArray[i].end(),
                     axis) != shardingOption.shardingArray[i].end()) {
-        if (ignoreIfConflicted)
-          return success();
-        else
-          return op->emitOpError()
-                 << "sharding option conflicts because mesh axes " << axis
-                 << " duplicate";
+        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
+                          << axis << " duplicate");
+        return failure();
       }
     }
   }
@@ -255,8 +251,9 @@ fillShardingOption(Operation *op, ShardingOption &shardingOption,
 
 } // namespace
 
-FailureOr<ShardingOption>
-mesh::detail::defaultGetShardingOption(Operation *op) {
+FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
+    Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings) {
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   ShardingOption shardingOption;
 
@@ -272,35 +269,34 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
   bool anyShardingInResultsOrOperands = false;
 
   // 1. Fill sharding option based on op results
-  for (OpResult result : op->getResults()) {
-    AffineMap map = maps[numOperands + result.getResultNumber()];
-    FailureOr<MeshShardingAttr> shardAttr =
-        getMeshShardingAttr(result, /*useOperandSharding*/ true);
-    if (failed(shardAttr))
+  for (auto shardingIt : llvm::enumerate(resultShardings)) {
+    MeshShardingAttr shardAttr = shardingIt.value();
+    if (!shardAttr)
       continue;
+    AffineMap map = maps[numOperands + shardingIt.index()];
     anyShardingInResultsOrOperands = true;
     // Handle the split axes: calculate the corresponding loop index for each
     // split axes sub-array, and then store the sub-array to
     // shardingOption[index]
-    for (auto it : llvm::zip(map.getResults(), shardAttr->getSplitAxes())) {
+    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
       ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
       auto dim = expr.cast<AffineDimExpr>();
       unsigned index = dim.getPosition();
       visitedLoopIndices.insert(index);
-      if (failed(fillShardingOption(op, shardingOption, shardAttr->getCluster(),
+      if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(),
                                     axes, index)))
         return failure();
     }
 
     // Handle the partial axes: at this stage, the exact loop index/indices
     // cannot be decided because there could be multiple reduction loops.
-    ArrayRef<int32_t> partialAxes = shardAttr->getPartialAxes();
+    ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
     if (!partialAxes.empty()) {
       if (!partialMeshAxes.empty())
         return op->emitOpError() << "at most one result with partial axes is "
                                     "supported at present";
-      partialType = shardAttr->getPartialType();
+      partialType = shardAttr.getPartialType();
       partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
       // Add all the reduction loop indices to `visitedLoopIndices` if
       // `partialAxes` is not empty
@@ -312,16 +308,13 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
   }
 
   // 2. Fill sharding option based on operands
-  for (OpOperand &opOperand : op->getOpOperands()) {
-    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
-        getMeshShardingAttr(opOperand);
-    if (failed(maybeShardAttr))
+  for (auto shardingIt : llvm::enumerate(operandShardings)) {
+    MeshShardingAttr shardAttr = shardingIt.value();
+    if (!shardAttr)
       continue;
 
     anyShardingInResultsOrOperands = true;
-    bool annotateForUsers = maybeShardAttr->first;
-    MeshShardingAttr shardAttr = maybeShardAttr->second;
-    AffineMap map = maps[opOperand.getOperandNumber()];
+    AffineMap map = maps[shardingIt.index()];
     unsigned numDims = map.getNumDims();
 
     // Handle the split axes. Partial axes don't need to be handled because they
@@ -344,8 +337,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
         unsigned loopIdx = *loopIndices->begin();
         visitedLoopIndices.insert(loopIdx);
         if (failed(fillShardingOption(op, shardingOption,
-                                      shardAttr.getCluster(), axes, loopIdx,
-                                      !annotateForUsers)))
+                                      shardAttr.getCluster(), axes, loopIdx)))
           return failure();
       }
       // If multiple loop indices correspond to a dimension of an operand, it is
@@ -361,7 +353,7 @@ mesh::detail::defaultGetShardingOption(Operation *op) {
         }
         if (!seenLoopIndices)
           return op->emitOpError()
-                 << "the operand " << opOperand.getOperandNumber()
+                 << "the operand " << shardingIt.index()
                  << " has multiple loop indices in a dimension, but none of "
                     "them could be found in the exactly specified annotation "
                     "of op results or operands.";
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 26804d1e290b29f..9da8955c135e976 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -34,6 +34,48 @@ namespace {
 // Utilities
 //===----------------------------------------------------------------------===//
 
+// This method returns all possible sharding attributes. For example,
+// mustShardings = [shard0, None] and optionalShardings = [None, shard1], the
+// result will be [[shard0, shard1], [shard0, None]]
+static SmallVector<SmallVector<MeshShardingAttr>>
+getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
+                                ArrayRef<MeshShardingAttr> optionalShardings) {
+  SmallVector<SmallVector<MeshShardingAttr>> allShardingAttrs;
+  SmallVector<MeshShardingAttr> curShardingAttrs;
+
+  std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
+    if (i == mustShardings.size()) {
+      allShardingAttrs.push_back(
+          SmallVector<MeshShardingAttr>(curShardingAttrs));
+      return;
+    }
+
+    if (mustShardings[i]) {
+      curShardingAttrs.push_back(mustShardings[i]);
+      dfsCreateShardingAttrs(i + 1);
+      curShardingAttrs.pop_back();
+      return;
+    }
+
+    if (optionalShardings[i]) {
+      curShardingAttrs.push_back(optionalShardings[i]);
+      dfsCreateShardingAttrs(i + 1);
+      curShardingAttrs.pop_back();
+      curShardingAttrs.push_back(nullptr);
+      dfsCreateShardingAttrs(i + 1);
+      curShardingAttrs.pop_back();
+      return;
+    }
+
+    curShardingAttrs.push_back(nullptr);
+    dfsCreateShardingAttrs(i + 1);
+    curShardingAttrs.pop_back();
+  };
+
+  dfsCreateShardingAttrs(0);
+  return allShardingAttrs;
+}
+
 // For each operation that implements the ShardingInterface, infer the sharding
 // option of the operation from its operands and/or results using the
 // `getShardingOption` method. If the inferred sharding option is not empty, add
@@ -49,16 +91,63 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
     return failure();
   }
 
-  FailureOr<ShardingOption> shardingOption = shardingOp.getShardingOption();
-  if (failed(shardingOption)) {
-    op->emitOpError() << "fail to get sharding option from results.";
+  // collect MeshShardingAttr from results
+  SmallVector<MeshShardingAttr> resultShardings;
+  resultShardings.reserve(op->getNumResults());
+  for (OpResult result : op->getResults()) {
+    FailureOr<MeshShardingAttr> shardAttr =
+        getMeshShardingAttr(result, /*useOperandSharding*/ true);
+    if (succeeded(shardAttr))
+      resultShardings.push_back(*shardAttr);
+    else
+      resultShardings.push_back(nullptr);
+  }
+
+  // collect MeshShardingAttr from operands
+  SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
+  allowConflictsOperandShardings.resize(op->getNumOperands());
+  SmallVector<MeshShardingAttr> operandMustShardings;
+  operandMustShardings.resize(op->getNumOperands());
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+        getMeshShardingAttr(opOperand);
+    if (failed(maybeShardAttr))
+      continue;
+
+    bool annotateForUsers = maybeShardAttr->first;
+    if (annotateForUsers)
+      operandMustShardings[opOperand.getOperandNumber()] =
+          maybeShardAttr->second;
+    else
+      allowConflictsOperandShardings[opOperand.getOperandNumber()] =
+          maybeShardAttr->second;
+  }
+
+  // try to get the sharding option
+  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
+      getOrderedPossibleShardingAttrs(operandMustShardings,
+                                      allowConflictsOperandShardings);
+  FailureOr<ShardingOption> finalShardingOption = failure();
+  for (ArrayRef<MeshShardingAttr> operandShardings :
+       possibleOperandShardingAttrs) {
+    FailureOr<ShardingOption> shardingOption =
+        shardingOp.getShardingOption(operandShardings, resultShardings);
+    if (succeeded(shardingOption)) {
+      finalShardingOption = shardingOption;
+      break;
+    }
+  }
+
+  if (failed(finalShardingOption)) {
+    op->emitOpError() << "fail to get sharding option.";
     return failure();
   }
   // sharding info is empty, return immediately
-  if (shardingOption->empty)
+  if (finalShardingOption->empty)
     return success();
 
-  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
+  if (failed(
+          shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
     op->emitOpError() << "fail to set sharding annotations.";
     return failure();
   }

>From b58581b18132f71cda9b56a9cf6884dceab36f07 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Thu, 26 Oct 2023 01:55:36 +0000
Subject: [PATCH 9/9] minor fix

---
 .../Mesh/Interfaces/ShardingInterface.h       | 12 ++--
 .../Mesh/Interfaces/ShardingInterface.cpp     | 55 ++++++++++---------
 .../Mesh/Transforms/ShardingPropagation.cpp   | 49 +++++++++++------
 3 files changed, 65 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 5e79bd255ee5221..d860628cf371aa9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -35,13 +35,13 @@ struct ShardingOption {
       : shardingArray(std::move(shardingArray)), cluster(cluster) {}
 };
 
-// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
-// for a given operation result.
-FailureOr<MeshShardingAttr> getMeshShardingAttr(OpResult result,
-                                                bool useOperandSharding);
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// result and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpResult result);
 
-// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
-// for a given operation operand.
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// operand and includes the 'annotate_for_users' information.
 FailureOr<std::pair<bool, MeshShardingAttr>>
 getMeshShardingAttr(OpOperand &opOperand);
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index f38001212a43d71..c2e1d1c726816a5 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -92,8 +92,8 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
 // mesh::getMeshShardingAttr
 //===----------------------------------------------------------------------===//
 
-FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr(OpResult result,
-                                                      bool useOperandSharding) {
+FailureOr<std::pair<bool, MeshShardingAttr>>
+mesh::getMeshShardingAttr(OpResult result) {
   Value val = result.cast<Value>();
   bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
     auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
@@ -108,32 +108,31 @@ FailureOr<MeshShardingAttr> mesh::getMeshShardingAttr(OpResult result,
     if (!val.hasOneUse())
       return failure();
     auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
-    return shardOp.getShard();
-  } else if (useOperandSharding) {
-    bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
-      auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
-      if (!shardOp)
-        return false;
-      return shardOp.getAnnotateForUsers();
-    });
-    if (anyShardedForUsers) {
-      SmallVector<ShardOp> shardOps;
-      for (Operation *user : val.getUsers()) {
-        ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
-        if (shardOp)
-          shardOps.push_back(shardOp);
-      }
-      MeshShardingAttr shardForDef = shardOps[0].getShard();
-      for (size_t i = 1; i < shardOps.size(); ++i) {
-        // TODO: Deduce a reasonable mesh sharding attr for def when they are
-        // different
-        assert(shardOps[i].getShard() == shardForDef &&
-               "only support all shard ops have the same mesh sharding attr");
-      }
-      return shardForDef;
-    }
+    return std::make_pair(false, shardOp.getShard());
   }
 
+  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      return false;
+    return shardOp.getAnnotateForUsers();
+  });
+  if (anyShardedForUsers) {
+    SmallVector<ShardOp> shardOps;
+    for (Operation *user : val.getUsers()) {
+      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
+      if (shardOp)
+        shardOps.push_back(shardOp);
+    }
+    MeshShardingAttr shardForDef = shardOps[0].getShard();
+    for (size_t i = 1; i < shardOps.size(); ++i) {
+      // TODO: Deduce a reasonable mesh sharding attr for def when they are
+      // different
+      assert(shardOps[i].getShard() == shardForDef &&
+             "only support all shard ops have the same mesh sharding attr");
+    }
+    return std::make_pair(true, shardForDef);
+  }
   return failure();
 }
 
@@ -403,7 +402,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 const ShardingOption &shardingOption,
                                 AffineMap map,
                                 ArrayRef<IteratorType> loopTypes) {
-  if (succeeded(getMeshShardingAttr(result, /*useOperandSharding*/ false)))
+  FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
+      getMeshShardingAttr(result);
+  if (succeeded(maybeSharding) && !maybeSharding->first)
     return success();
 
   auto resultType = result.getType().cast<RankedTensorType>();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9da8955c135e976..3aed912fb43c63e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -34,9 +34,10 @@ namespace {
 // Utilities
 //===----------------------------------------------------------------------===//
 
-// This method returns all possible sharding attributes. For example,
-// mustShardings = [shard0, None] and optionalShardings = [None, shard1], the
-// result will be [[shard0, shard1], [shard0, None]]
+// This method retrieves all potential sharding attributes, prioritizing
+// specific shardings. For example, mustShardings = [shard0, None] and
+// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
+// [shard0, None]]
 static SmallVector<SmallVector<MeshShardingAttr>>
 getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
                                 ArrayRef<MeshShardingAttr> optionalShardings) {
@@ -92,15 +93,20 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   }
 
   // collect MeshShardingAttr from results
-  SmallVector<MeshShardingAttr> resultShardings;
-  resultShardings.reserve(op->getNumResults());
+  SmallVector<MeshShardingAttr> allowConflictsResultShardings;
+  allowConflictsResultShardings.resize(op->getNumResults());
+  SmallVector<MeshShardingAttr> resultMustShardings;
+  resultMustShardings.resize(op->getNumResults());
   for (OpResult result : op->getResults()) {
-    FailureOr<MeshShardingAttr> shardAttr =
-        getMeshShardingAttr(result, /*useOperandSharding*/ true);
-    if (succeeded(shardAttr))
-      resultShardings.push_back(*shardAttr);
+    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+        getMeshShardingAttr(result);
+    if (failed(maybeShardAttr))
+      continue;
+    if (!maybeShardAttr->first)
+      resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
     else
-      resultShardings.push_back(nullptr);
+      allowConflictsResultShardings[result.getResultNumber()] =
+          maybeShardAttr->second;
   }
 
   // collect MeshShardingAttr from operands
@@ -114,8 +120,7 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
     if (failed(maybeShardAttr))
       continue;
 
-    bool annotateForUsers = maybeShardAttr->first;
-    if (annotateForUsers)
+    if (maybeShardAttr->first)
       operandMustShardings[opOperand.getOperandNumber()] =
           maybeShardAttr->second;
     else
@@ -127,14 +132,22 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
       getOrderedPossibleShardingAttrs(operandMustShardings,
                                       allowConflictsOperandShardings);
+  SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
+      getOrderedPossibleShardingAttrs(resultMustShardings,
+                                      allowConflictsResultShardings);
   FailureOr<ShardingOption> finalShardingOption = failure();
-  for (ArrayRef<MeshShardingAttr> operandShardings :
-       possibleOperandShardingAttrs) {
-    FailureOr<ShardingOption> shardingOption =
-        shardingOp.getShardingOption(operandShardings, resultShardings);
-    if (succeeded(shardingOption)) {
-      finalShardingOption = shardingOption;
+  for (ArrayRef<MeshShardingAttr> resultShardings :
+       possibleResultShardingAttrs) {
+    if (succeeded(finalShardingOption))
       break;
+    for (ArrayRef<MeshShardingAttr> operandShardings :
+         possibleOperandShardingAttrs) {
+      FailureOr<ShardingOption> shardingOption =
+          shardingOp.getShardingOption(operandShardings, resultShardings);
+      if (succeeded(shardingOption)) {
+        finalShardingOption = shardingOption;
+        break;
+      }
     }
   }
 



More information about the Mlir-commits mailing list