[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)

Chengji Yao llvmlistbot at llvm.org
Thu Oct 19 17:41:37 PDT 2023


https://github.com/yaochengji created https://github.com/llvm/llvm-project/pull/69665

- add `ShardingInterface` and default implementations of its methods
- add `ShardingInterface` implementations for element-wise and matmul ops in the TOSA dialect
- add a sharding propagation pass

>From 25d68a6caf9c32747ab011b6d955d914573c2f55 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Wed, 18 Oct 2023 18:54:49 +0000
Subject: [PATCH] [MLIR][Mesh] Add sharding propagation pass

---
 mlir/include/mlir/Dialect/Mesh/CMakeLists.txt |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  34 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  22 +
 .../Dialect/Mesh/Interfaces/CMakeLists.txt    |   4 +
 .../Mesh/Interfaces/ShardingInterface.h       |  58 ++
 .../Mesh/Interfaces/ShardingInterface.td      |  87 +++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   6 +
 .../mlir/Dialect/Mesh/Transforms/Passes.h     |  41 ++
 .../mlir/Dialect/Mesh/Transforms/Passes.td    |  33 ++
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.h   |  23 +
 .../mlir/Dialect/Utils/IndexingUtils.h        |   3 +
 mlir/include/mlir/IR/AffineMap.h              |  12 +
 mlir/include/mlir/IR/Builders.h               |   1 +
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/include/mlir/InitAllPasses.h             |   2 +
 mlir/lib/Dialect/Mesh/CMakeLists.txt          |   2 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  32 +-
 .../Dialect/Mesh/Interfaces/CMakeLists.txt    |  15 +
 .../Mesh/Interfaces/ShardingInterface.cpp     | 544 ++++++++++++++++++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |  17 +
 .../Mesh/Transforms/ShardingPropagation.cpp   | 155 +++++
 mlir/lib/Dialect/Tosa/CMakeLists.txt          |  14 +
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 111 ++++
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      |  13 +
 mlir/lib/IR/AffineMap.cpp                     |  13 +
 mlir/lib/IR/Builders.cpp                      |   9 +
 .../Dialect/Mesh/sharding-propagation.mlir    | 167 ++++++
 27 files changed, 1421 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/Mesh/sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..b6623ed818f0770 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+  I32EnumAttrCase<"Parallel", 1, "parallel">,
+  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+  I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+    let genSpecializedAttr = 0;
+    let cppNamespace = "::mlir::mesh";
+}
+
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
        $partial_axes^ `]`)? `>`
   }];
 
+  let builders = [
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
+                     "ArrayRef<int32_t>": $partial_axes,
+                     "mesh::Partial": $partial_type), [{
+      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
+        llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+          return DenseI32ArrayAttr::get($_ctxt, array);
+      }));
+      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+                   partial_type);
+    }]>,
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+    }]>
+  ];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..cb86887091330c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,26 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+/// Removes the trailing empty sub-arrays of `array` in place. Empty
+/// sub-arrays that precede a non-empty one are kept.
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+  // Pop from the back until the last sub-array is non-empty or nothing is
+  // left.
+  while (!array.empty() && array.back().empty())
+    array.pop_back();
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..1d19e41ac1fc555
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+/// Sharding decision for an operation, expressed per loop.
+struct ShardingOption {
+  // An array of int arrays. The sub-array at the i-th position lists the mesh
+  // axes that the i-th loop will be sharded on.
+  ShardingArray shardingArray;
+  // Reference to the mesh cluster that the axes in `shardingArray` belong to
+  // (presumably a `mesh.cluster` symbol — confirm against callers).
+  SymbolRefAttr cluster;
+  // `empty` being true means no sharding information can be inferred yet.
+  // Note that this differs from an operation that is known not to be sharded.
+  bool empty = false;
+  ShardingOption() = default;
+  ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+      : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+namespace detail {
+
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+                              const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..c98b9f081492997
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,87 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+    let description = [{
+        Interface for allowing operations to expose information needed to
+        shard them.
+    }];
+    let cppNamespace = "::mlir::mesh";
+
+    let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of iterator types that describe the number of loops.
+        }],
+        /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*methodName=*/"getLoopIteratorTypes",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the indexing maps attribute within the current operation.
+        }],
+        /*retTy=*/"SmallVector<AffineMap>",
+        /*methodName=*/"getIndexingMaps",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Given that certain operands or results of the operation may have
+          sharding annotations, this method leverages this information to deduce
+          how the operation should be sharded.
+        }],
+        /*retTy=*/"FailureOr<ShardingOption>",
+        /*methodName=*/"getShardingOption",
+        /*args=*/(ins
+          "OpBuilder &":$b
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultGetShardingOption(
+            $_op.getOperation(), b);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Based on a given ShardingOption, this method adds `mesh.shard`
+          operations for the operands and results that previously lacked
+          sharding annotations.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"addShardingAnnotations",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "const ShardingOption &":$shardingOption
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultAddShardingAnnotations(
+            $_op.getOperation(), b, shardingOption);
+        }]
+      >
+    ];
+
+    let extraClassDeclaration = [{
+      LogicalResult verifyShardingInterfaceImpl();
+
+      void printLoopTypesAndIndexingMaps(raw_ostream &os);
+    }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..aa3555f7f186f24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..d36adfe476a72ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,33 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+  let summary = "sharding propagation";
+  let description = [{
+    Propagates sharding information throughout the graph. After this pass, each
+    of the operations' operands and results is annotated with a `mesh.shard`
+    operation, and the operations themselves are added with sharding option
+    attributes.
+  }];
+  let constructor = "mlir::mesh::createShardingPropagationPass()";
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548ed..b24164cfb552b4f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
+/// Helper to return a vector of sub-vectors of int32_t extracted from `arrayAttr`.
+SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
+
 /// Compute linear index from provided strides and indices, assuming strided
 /// layout.
 /// Returns AffineExpr and list of values to apply to it, e.g.:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 3430db2b99c3f2e..18e2313ef2b446b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -101,6 +101,18 @@ class AffineMap {
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns an affine map with `numDims` input dimensions and results
+  /// specified by `targets`.
+  ///
+  /// Examples:
+  /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+  /// * getMultiDimMapWithTargets(3, [2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d2, d1)>
+  static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+                                      ArrayRef<int64_t> targets,
+                                      MLIRContext *context);
+
   /// Returns a vector of AffineMaps; each with as many results as
   /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
   /// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5e54d4ea49e8251..3988835622b7629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -168,6 +168,7 @@ class Builder {
   ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
   ArrayAttr getTypeArrayAttr(TypeRange values);
+  ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
 
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..3556f82023828b2 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+  tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..27711417ed91a8c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -73,6 +74,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
   registerSCFPasses();
   registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+// A loop counts as a reduction loop unless it is parallel or invalid.
+bool mesh::isReductionLoop(IteratorType iType) {
+  return !(iType == IteratorType::Parallel || iType == IteratorType::Invalid);
+}
+
+// Returns true iff `iType` is a reduction whose kind is the same as
+// `partial` (ReductionSum/Sum, ReductionMax/Max, ReductionMin/Min,
+// ReductionGeneric/Generic). Non-reduction iterator types never match.
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return partial == Partial::Generic;
+  case IteratorType::ReductionSum:
+    return partial == Partial::Sum;
+  case IteratorType::ReductionMax:
+    return partial == Partial::Max;
+  case IteratorType::ReductionMin:
+    return partial == Partial::Min;
+  default:
+    return false;
+  }
+}
+
+// Maps each reduction iterator type to the Partial enumerator of the same
+// kind (ReductionSum -> Sum, ReductionMax -> Max, ...).
+// Precondition: `iType` is one of the four reduction kinds.
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return Partial::Generic;
+  case IteratorType::ReductionSum:
+    return Partial::Sum;
+  case IteratorType::ReductionMax:
+    return Partial::Max;
+  case IteratorType::ReductionMin:
+    return Partial::Min;
+  default:
+    // llvm_unreachable is noreturn, so no path falls off the end of this
+    // non-void function (a bare assert compiles away under NDEBUG and left
+    // this path returning nothing). Assertion builds still report the message.
+    llvm_unreachable("No corresponding partial type can be found");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.cluster op
 //===----------------------------------------------------------------------===//
@@ -95,7 +126,6 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   }
   if (failed(checkMeshAxis(partialAxes)))
     return failure();
-
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..1010756f1fe279a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_library(MLIRShardingInterface
+  ShardingInterface.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRShardingInterfaceIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMeshDialect
+  MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
new file mode 100644
index 000000000000000..7d5c73851bb1f4e
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -0,0 +1,544 @@
+//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// common util functions
+//===----------------------------------------------------------------------===//
+
+// Rebuilds a ShardingOption from the sharding-array and mesh-cluster
+// attributes attached to `op`.
+static FailureOr<ShardingOption> getShardingOptionFromAttr(Operation *op) {
+  // Both attributes are required; fail if either one is missing.
+  auto shardingArrayAttr =
+      op->getAttrOfType<ArrayAttr>(getShardingArrayName());
+  auto clusterAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName());
+  if (!shardingArrayAttr || !clusterAttr)
+    return failure();
+  return ShardingOption(getArrayOfI32Array(shardingArrayAttr), clusterAttr);
+}
+
+// Retrieves the mesh sharding attribute (MeshShardingAttr) for `result`.
+//
+// A `mesh.shard` user without the `annotate_for_users` unit attribute
+// annotates the result itself and takes precedence. Otherwise, when
+// `useOperandSharding` is set, the sharding of the `annotate_for_users`
+// shard ops among the users is used; all of them are currently required to
+// agree. Returns failure if no sharding can be found.
+static FailureOr<MeshShardingAttr>
+getMeshShardingAttr(OpResult result, bool useOperandSharding) {
+  // OpResult already is a Value, so no (deprecated member) cast is needed.
+  Value val = result;
+  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      return false;
+    return !shardOp.getAnnotateForUsers();
+  });
+
+  if (anyShardedForDef) {
+    assert(val.hasOneUse() &&
+           "expected to has exact one use if it has a use of mesh.shard "
+           "without unit attr annotate_for_users");
+    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
+    return shardOp.getShard();
+  }
+
+  if (!useOperandSharding)
+    return failure();
+
+  // No shard op annotates the definition here, so every mesh.shard user
+  // carries annotate_for_users; collect them in a single pass.
+  SmallVector<ShardOp> shardOps;
+  for (Operation *user : val.getUsers())
+    if (auto shardOp = llvm::dyn_cast<ShardOp>(user))
+      shardOps.push_back(shardOp);
+  if (shardOps.empty())
+    return failure();
+
+  MeshShardingAttr shardForDef = shardOps.front().getShard();
+  // TODO: Deduce a reasonable mesh sharding attr for def when they are
+  // different.
+  assert(llvm::all_of(shardOps,
+                      [&](ShardOp shardOp) {
+                        return shardOp.getShard() == shardForDef;
+                      }) &&
+         "only support all shard ops have the same mesh sharding attr");
+  return shardForDef;
+}
+
+// Returns the sharding of `opOperand` when its value is produced by a
+// `mesh.shard` op, paired with that op's `annotate_for_users` flag. Returns
+// failure otherwise.
+static FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand) {
+  auto shardOp = opOperand.get().getDefiningOp<ShardOp>();
+  if (!shardOp)
+    return failure();
+  return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
+// Recursively validates an operand indexing expression. Accepted forms are a
+// dimension `d_i`, `d_i * c` or `c * d_i` with a constant `c`, and sums of
+// these. Each dimension may appear at most once and must be below
+// `seenIds.size()`; every dimension that appears is marked in `seenIds`.
+static LogicalResult
+checkOperandAffineExprRecursively(AffineExpr expr,
+                                  SmallVectorImpl<bool> &seenIds) {
+  // Marks the dimension `dim` as seen; fails on out-of-range or repeated dims.
+  auto markDim = [&seenIds](AffineExpr dim) -> LogicalResult {
+    unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+    if (static_cast<size_t>(pos) >= seenIds.size() || seenIds[pos])
+      return failure();
+    seenIds[pos] = true;
+    return success();
+  };
+
+  switch (expr.getKind()) {
+  case AffineExprKind::DimId:
+    return markDim(expr);
+  case AffineExprKind::Add: {
+    auto sum = expr.cast<AffineBinaryOpExpr>();
+    if (failed(checkOperandAffineExprRecursively(sum.getLHS(), seenIds)))
+      return failure();
+    return checkOperandAffineExprRecursively(sum.getRHS(), seenIds);
+  }
+  case AffineExprKind::Mul: {
+    auto product = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = product.getLHS();
+    AffineExpr rhs = product.getRHS();
+    // One side must be a dimension and the other a constant.
+    if (lhs.getKind() == AffineExprKind::DimId) {
+      if (rhs.getKind() != AffineExprKind::Constant)
+        return failure();
+      return markDim(lhs);
+    }
+    if (rhs.getKind() == AffineExprKind::DimId &&
+        lhs.getKind() == AffineExprKind::Constant)
+      return markDim(rhs);
+    return failure();
+  }
+  default:
+    return failure();
+  }
+}
+
+// Validates `expr` with checkOperandAffineExprRecursively for an op with
+// `numDims` loops and returns the set of dimension positions it uses.
+static FailureOr<llvm::SmallSet<unsigned, 2>>
+checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
+  SmallVector<bool> seenIds(numDims, false);
+  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
+    return failure();
+
+  llvm::SmallSet<unsigned, 2> positions;
+  for (unsigned dim = 0; dim < numDims; ++dim)
+    if (seenIds[dim])
+      positions.insert(dim);
+  return positions;
+}
+
+// Checks that the op has only ranked-tensor operands and results, at least
+// one loop iterator type, and exactly one indexing map per operand followed by
+// one per result, where every result map is a projected permutation.
+LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+  Operation *op = getOperation();
+
+  // Check operand and result types.
+  auto isRankedTensor = [](Type type) {
+    return llvm::isa<RankedTensorType>(type);
+  };
+  if (!llvm::all_of(op->getOperandTypes(), isRankedTensor) ||
+      !llvm::all_of(op->getResultTypes(), isRankedTensor))
+    return failure();
+
+  // Check loop types.
+  if (getLoopIteratorTypes().empty())
+    return failure();
+
+  // Check maps: one per operand, followed by one per result.
+  SmallVector<AffineMap> maps = getIndexingMaps();
+  if (maps.empty())
+    return failure();
+  unsigned numOperands = op->getNumOperands();
+  if (numOperands + op->getNumResults() != maps.size())
+    return failure();
+
+  // Result types were already verified above, so only the result maps need
+  // checking here.
+  for (AffineMap map : llvm::drop_begin(maps, numOperands))
+    if (!map.isProjectedPermutation())
+      return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::printLoopTypesAndIndexingMaps
+//===----------------------------------------------------------------------===//
+
+/// Debug helper: prints the op, then the iterator type of every loop and every
+/// indexing map reported by this ShardingInterface implementation.
+void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+  Operation *self = getOperation();
+  os << "print loop types and indexing maps for: \n";
+  self->print(os);
+  os << "\n" << "loop types: [";
+  for (IteratorType iterType : getLoopIteratorTypes())
+    os << stringifyEnum(iterType) << " ";
+  os << "]\n" << "indexing maps: \n";
+  for (AffineMap indexingMap : getIndexingMaps())
+    os << indexingMap << "\n";
+  os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultGetShardingOption
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Update `shardingOption` so that loop `loopIdx` is sharded along `meshAxes`
+// on `cluster` (a null `cluster` leaves the recorded cluster untouched).
+// Fails, emitting an error on `op`, when `loopIdx` is out of range, or when the
+// update conflicts with what is already recorded: a different cluster, a
+// different non-empty sub-array for `loopIdx`, or one of `meshAxes` already
+// used by another loop. With `ignoreIfConflicted`, a conflict leaves
+// `shardingOption` unchanged and returns success instead.
+static LogicalResult
+fillShardingOption(Operation *op, ShardingOption &shardingOption,
+                   SymbolRefAttr cluster, ArrayRef<int32_t> meshAxes,
+                   unsigned loopIdx, bool ignoreIfConflicted = false) {
+  auto &shardingArray = shardingOption.shardingArray;
+  // Guard the indexing below against an inconsistent interface implementation.
+  if (loopIdx >= shardingArray.size())
+    return op->emitOpError()
+           << "loop index " << loopIdx << " is out of range of the "
+           << shardingArray.size() << " loops in the sharding option";
+
+  auto &current = shardingArray[loopIdx];
+  bool clusterConflict =
+      shardingOption.cluster && cluster && shardingOption.cluster != cluster;
+  bool loopConflict = !current.empty() && current != meshAxes;
+  if (clusterConflict || loopConflict) {
+    if (ignoreIfConflicted)
+      return success();
+    return op->emitOpError()
+           << "sharding option conflicts on loop iterator " << loopIdx;
+  }
+
+  // A mesh axis may shard at most one loop.
+  for (size_t i = 0, e = shardingArray.size(); i < e; ++i) {
+    if (i == loopIdx)
+      continue;
+    for (int32_t axis : meshAxes) {
+      if (std::find(shardingArray[i].begin(), shardingArray[i].end(), axis) ==
+          shardingArray[i].end())
+        continue;
+      if (ignoreIfConflicted)
+        return success();
+      return op->emitOpError()
+             << "sharding option conflicts because of mesh axis " << axis
+             << " duplicates";
+    }
+  }
+
+  if (cluster)
+    shardingOption.cluster = cluster;
+  if (current.empty())
+    current.append(meshAxes.begin(), meshAxes.end());
+  return success();
+}
+
+} // namespace
+
+/// Default implementation of `ShardingInterface::getShardingOption`.
+///
+/// If `getShardingOptionFromAttr` succeeds for `op`, that option is returned
+/// as-is. Otherwise the implementation is verified and a sharding option (one
+/// mesh-axes sub-array per loop) is inferred from the sharding annotations of
+/// the op's results and operands:
+///   - result split axes are mapped to loops through the result indexing maps;
+///   - result partial axes are assigned at the end to the first reduction loop
+///     whose kind matches the partial type, unless a reduction loop is already
+///     sharded;
+///   - operand split axes are mapped to a loop when the operand dimension
+///     depends on a single loop index; conflicts from operand annotations whose
+///     `annotateForUsers` flag is false are ignored.
+/// Trailing empty sub-arrays are dropped and the option is flagged `empty` when
+/// no annotation was found. `b` is not used by this implementation.
+FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(Operation *op,
+                                                                 OpBuilder &b) {
+
+  // 1. If a valid sharding attribute exists, use it.
+  FailureOr<ShardingOption> shardingOptionFromAttr =
+      getShardingOptionFromAttr(op);
+  if (succeeded(shardingOptionFromAttr))
+    return shardingOptionFromAttr;
+
+  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+  ShardingOption shardingOption;
+
+  if (failed(shardingOp.verifyShardingInterfaceImpl()))
+    return op->emitOpError() << "invalid sharding interface implementation";
+  SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+  unsigned numOperands = op->getNumOperands();
+  shardingOption.shardingArray.resize(loopTypes.size());
+  // Mesh axes (and reduction kind) of the single result carrying partial axes;
+  // `partialType` is only read once `partialMeshAxes` is non-empty, and both
+  // are set together.
+  llvm::SmallVector<int32_t> partialMeshAxes;
+  Partial partialType;
+  // Loops whose sharding is pinned by an annotation seen so far.
+  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
+  bool anyShardingInResultsOrOperands = false;
+
+  // 2. Fill sharding option based on op results
+  for (OpResult result : op->getResults()) {
+    AffineMap map = maps[numOperands + result.getResultNumber()];
+    FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr(result, true);
+    if (failed(shardAttr))
+      continue;
+    anyShardingInResultsOrOperands = true;
+    // Handle the split axes: calculate the corresponding loop index for each
+    // split axes sub-array, and then store the sub-array to
+    // shardingOption[index]
+    for (auto it : llvm::zip(map.getResults(), shardAttr->getSplitAxes())) {
+      AffineExpr expr = std::get<0>(it);
+      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      auto dim = expr.cast<AffineDimExpr>();
+      unsigned index = dim.getPosition();
+      visitedLoopIndices.insert(index);
+      if (failed(fillShardingOption(op, shardingOption, shardAttr->getCluster(),
+                                    axes, index)))
+        return failure();
+    }
+
+    // Handle the partial axes: at this stage, the exact loop index/indices
+    // cannot be decided because there could be multiple reduction loops.
+    ArrayRef<int32_t> partialAxes = shardAttr->getPartialAxes();
+    if (!partialAxes.empty()) {
+      if (!partialMeshAxes.empty())
+        return op->emitOpError() << "at most one result with partial axes is "
+                                    "supported at present";
+      partialType = shardAttr->getPartialType();
+      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
+      // Add all the reduction loop indices to `visitedLoopIndices` if
+      // `partialAxes` is not empty
+      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
+        if (isReductionLoop(loopTypes[loopIdx]))
+          visitedLoopIndices.insert(loopIdx);
+      }
+    }
+  }
+
+  // 3. Fill sharding option based on operands
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+        getMeshShardingAttr(opOperand);
+    if (failed(maybeShardAttr))
+      continue;
+
+    anyShardingInResultsOrOperands = true;
+    bool annotateForUsers = maybeShardAttr->first;
+    MeshShardingAttr shardAttr = maybeShardAttr->second;
+    AffineMap map = maps[opOperand.getOperandNumber()];
+    unsigned numDims = map.getNumDims();
+
+    // Handle the split axes. Partial axes don't need to be handled because
+    // they only affect the defining op of the operand.
+    //
+    // TODO: Change to process the operands with single loop index first and
+    // then the operands with multiple loop indices
+    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
+      AffineExpr expr = std::get<0>(it);
+      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+          checkOperandAffineExpr(expr, numDims);
+      if (failed(loopIndices))
+        return op->emitOpError()
+               << "operand's affine expression is restricted to const_i * "
+                  "dim_i + const_j + dim_j + ...";
+      if (loopIndices->empty())
+        continue;
+      if (loopIndices->size() == 1) {
+        unsigned loopIdx = *loopIndices->begin();
+        visitedLoopIndices.insert(loopIdx);
+        // Conflicts are only fatal for annotations whose `annotateForUsers`
+        // flag is set; otherwise the conflicting annotation is skipped.
+        if (failed(fillShardingOption(op, shardingOption,
+                                      shardAttr.getCluster(), axes, loopIdx,
+                                      !annotateForUsers)))
+          return failure();
+      }
+      // If multiple loop indices correspond to a dimension of an operand, it is
+      // difficult to infer which loop indices are responsible for sharding.
+      // Therefore, the exact loop index must be specified by others.
+      if (loopIndices->size() > 1) {
+        bool seenLoopIndices = false;
+        for (unsigned loopIdx : *loopIndices) {
+          if (visitedLoopIndices.contains(loopIdx)) {
+            seenLoopIndices = true;
+            break;
+          }
+        }
+        if (!seenLoopIndices)
+          return op->emitOpError()
+                 << "the operand " << opOperand.getOperandNumber()
+                 << " has multiple loop indices in a dimension, but none of "
+                    "them could be found in the exactly specified annotation "
+                    "of op results or operands.";
+      }
+    }
+  }
+
+  // 4. Finalize sharding option
+  if (!partialMeshAxes.empty()) {
+    bool anyNonEmptyReductionLoop = llvm::any_of(
+        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
+          SmallVector<int32_t> &subArray = it.value();
+          int64_t idx = it.index();
+          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
+        });
+    if (!anyNonEmptyReductionLoop) {
+      bool filled = false;
+      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
+        if (isReductionLoop(loopTypes[idx]) &&
+            areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+          // The partial axes must not clash with axes already used to split
+          // other loops. `fillShardingOption` has already emitted an error in
+          // that case, so propagate the failure instead of dropping it.
+          if (failed(fillShardingOption(op, shardingOption, nullptr,
+                                        partialMeshAxes, idx)))
+            return failure();
+          filled = true;
+          break;
+        }
+      }
+      if (!filled)
+        return op->emitOpError() << "no matched reduction loop found for the "
+                                    "result's partial type";
+    }
+  }
+  removeTrailingEmptySubArray(shardingOption.shardingArray);
+  if (!anyShardingInResultsOrOperands)
+    shardingOption.empty = true;
+  return shardingOption;
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultAddShardingAnnotations
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Add a `mesh.shard` op (not annotated for users) right after `result`, unless
+// `getMeshShardingAttr(result, false)` already finds one. The split axes of
+// each result dimension come from the loop that indexes it in `map`. The
+// partial axes are gathered from the sharded reduction loops in `loopTypes`.
+// All uses of `result` other than the new op are redirected to it.
+static LogicalResult addShardOp(OpBuilder &b, OpResult result,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<IteratorType> loopTypes) {
+  if (succeeded(getMeshShardingAttr(result, false)))
+    return success();
+
+  auto resultType = llvm::cast<RankedTensorType>(result.getType());
+  SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
+  SmallVector<int32_t> partialAxes;
+
+  // process the split axes
+  for (auto it : llvm::enumerate(map.getResults())) {
+    AffineExpr expr = it.value();
+    auto dim = expr.cast<AffineDimExpr>();
+    unsigned loopIdx = dim.getPosition();
+    if (loopIdx < shardingOption.shardingArray.size())
+      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
+  }
+
+  // process the partial axes: only reduction loops that are actually sharded
+  // contribute. `partialType` starts at `Partial::Sum` (presumably the same
+  // default as the split-axes-only `MeshShardingAttr::get` builder; confirm).
+  // Previously it was passed to `MeshShardingAttr::get` uninitialized when no
+  // reduction loop existed.
+  Partial partialType = Partial::Sum;
+  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
+    IteratorType iType = std::get<0>(it);
+    const SmallVector<int32_t> &axes = std::get<1>(it);
+    if (!isReductionLoop(iType) || axes.empty())
+      continue;
+    Partial curPartialType = getPartialTypeFromReduction(iType);
+    assert((partialAxes.empty() || partialType == curPartialType) &&
+           "Only one reduction type is supported");
+    partialType = curPartialType;
+    partialAxes.append(axes);
+  }
+
+  removeTrailingEmptySubArray(splitAxes);
+  MeshShardingAttr shardAttr =
+      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
+                            partialAxes, partialType);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfterValue(result);
+  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
+                                   shardAttr, /*annotate_for_users*/ false);
+  result.replaceAllUsesExcept(shardOp, shardOp);
+  return success();
+}
+
+// Insert a `mesh.shard` op annotated for users right before the owner of
+// `opOperand` and route the operand through it, unless the operand already
+// carries such an annotation. Each operand dimension takes the split axes of
+// the at most one sharded loop that its expression in `map` refers to. The
+// function fails if an expression is rejected by `checkOperandAffineExpr` or
+// refers to several sharded loops. `loopTypes` is not used here.
+static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<IteratorType> loopTypes) {
+  FailureOr<std::pair<bool, MeshShardingAttr>> existing =
+      getMeshShardingAttr(opOperand);
+  if (succeeded(existing) && existing->first)
+    return success();
+
+  Value operand = opOperand.get();
+  auto operandType = operand.getType().cast<RankedTensorType>();
+  const auto &loopSharding = shardingOption.shardingArray;
+  SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+  unsigned numDims = map.getNumDims();
+  ArrayRef<AffineExpr> exprs = map.getResults();
+  for (size_t dimIdx = 0; dimIdx < exprs.size(); ++dimIdx) {
+    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+        checkOperandAffineExpr(exprs[dimIdx], numDims);
+    if (failed(loopIndices))
+      return failure();
+
+    // At most one of the referenced loops may be sharded.
+    SmallVector<unsigned> sharded;
+    for (unsigned loopIdx : *loopIndices)
+      if (loopIdx < loopSharding.size() && !loopSharding[loopIdx].empty())
+        sharded.push_back(loopIdx);
+    if (sharded.size() > 1)
+      return failure();
+    if (!sharded.empty())
+      splitAxes[dimIdx].append(loopSharding[sharded.front()]);
+  }
+
+  removeTrailingEmptySubArray(splitAxes);
+  auto shardAttr =
+      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(opOperand.getOwner());
+  ShardOp shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
+                                      shardAttr, /*annotate_for_users*/ true);
+  opOperand.set(shardOp);
+  return success();
+}
+
+} // namespace
+
+/// Default implementation of `ShardingInterface::addShardingAnnotations`.
+/// It annotates every result of `op` first, then every operand, according to
+/// `shardingOption` and the op's indexing maps, and stops at the first failure.
+LogicalResult mesh::detail::defaultAddShardingAnnotations(
+    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
+  auto shardingOp = llvm::cast<ShardingInterface>(op);
+  SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+  // Indexing maps list the operands first, followed by the results.
+  unsigned firstResultMap = op->getNumOperands();
+
+  // 1. add mesh.shard ops for all op results
+  for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
+    OpResult result = op->getResult(i);
+    if (failed(addShardOp(b, result, shardingOption, maps[firstResultMap + i],
+                          loopTypes)))
+      return failure();
+  }
+
+  // 2. add mesh.shard ops for all operands
+  for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
+    OpOperand &operand = op->getOpOperand(i);
+    if (failed(addShardOp(b, operand, shardingOption, maps[i], loopTypes)))
+      return failure();
+  }
+
+  return success();
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..9f85d8e9cb22d5b
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRMeshTransforms
+  ShardingPropagation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRMeshPassIncGen
+  MLIRShardingInterface
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMeshDialect
+  MLIRShardingInterface
+  MLIRSupport
+  MLIRTosaShardingInterfaceImpl
+)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
new file mode 100644
index 000000000000000..a3f305444afd835
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -0,0 +1,155 @@
+//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+#include <vector>
+
+namespace mlir {
+namespace mesh {
+#define GEN_PASS_DEF_SHARDINGPROPAGATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+} // namespace mesh
+} // namespace mlir
+
+#define DEBUG_TYPE "sharding-propagation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns a snapshot of the operations of `block` in program order. Working
+/// on a snapshot keeps the walk stable while new ops are inserted into the
+/// block.
+static std::vector<Operation *> getOperationsVector(Block &block) {
+  std::vector<Operation *> ops;
+  for (Operation &op : block)
+    ops.push_back(&op);
+  return ops;
+}
+
+/// Returns a snapshot of the operations of `block` from last to first.
+static std::vector<Operation *> getReversedOperationsVector(Block &block) {
+  std::vector<Operation *> forward = getOperationsVector(block);
+  return std::vector<Operation *>(forward.rbegin(), forward.rend());
+}
+
+// For each operation that implements the ShardingInterface, infer the sharding
+// option of the operation from its operands and/or results using the
+// `getShardingOption` method. If the inferred sharding option is not empty,
+// record it on the op as attributes and add a `mesh.shard` operation for all
+// remaining operands and results that do not have sharding annotations.
+// Terminators and existing `mesh.shard` ops are skipped. Any other op that
+// does not implement the interface is an error.
+LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
+    return success();
+
+  ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
+  if (!shardingOp)
+    return op->emitOpError() << "sharding interface is not implemented.";
+
+  FailureOr<ShardingOption> shardingOption =
+      shardingOp.getShardingOption(builder);
+  if (failed(shardingOption))
+    return op->emitOpError() << "fail to get sharding option from results.";
+  // sharding info is empty, return immediately
+  if (shardingOption->empty)
+    return success();
+
+  // Build the attribute once and reuse it for the debug output and the op.
+  ArrayAttr shardingArrayAttr =
+      builder.getArrayOfI32ArrayAttr(shardingOption->shardingArray);
+  LLVM_DEBUG(DBGS() << "mesh cluster: " << shardingOption->cluster << "\n");
+  LLVM_DEBUG(DBGS() << "sharding array: " << shardingArrayAttr << "\n");
+  op->setAttr(getMeshClusterName(), shardingOption->cluster);
+  op->setAttr(getShardingArrayName(), shardingArrayAttr);
+
+  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption)))
+    return op->emitOpError() << "fail to set sharding annotations.";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagationPass
+//===----------------------------------------------------------------------===//
+/// Propagates sharding annotations through a single-block function body. Ops
+/// are visited first from last to first, then from first to last. Each visit
+/// (see `visitOp`) may annotate the op and insert `mesh.shard` ops around it.
+struct ShardingPropagationPass
+    : public mesh::impl::ShardingPropagationBase<ShardingPropagationPass> {
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    // Declarations have no body, so there is nothing to propagate.
+    if (funcOp.isExternal())
+      return;
+    MLIRContext *ctx = funcOp.getContext();
+    Region &region = funcOp.getBody();
+    OpBuilder builder(ctx);
+    if (!region.hasOneBlock()) {
+      funcOp.emitOpError() << "only one block is supported!";
+      // Bail out: `region.front()` below requires a block to exist.
+      return signalPassFailure();
+    }
+    Block &block = region.front();
+
+    // Debug-only: dump iterator types and indexing maps of every
+    // ShardingInterface op, in block order.
+    LLVM_DEBUG({
+      DBGS() << "print all the ops' iterator types and indexing maps in the "
+                "block.\n";
+      for (Operation &op : block.getOperations())
+        if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
+          shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
+    });
+
+    // 1. propagate in reversed order (iterate over a snapshot, because
+    // visiting inserts new ops into the block)
+    for (Operation *op : getReversedOperationsVector(block))
+      if (failed(visitOp(op, builder)))
+        return signalPassFailure();
+
+    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+                      << funcOp << "\n");
+
+    // 2. propagate in original order
+    for (Operation *op : getOperationsVector(block))
+      if (failed(visitOp(op, builder)))
+        return signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Creates a new instance of the sharding propagation pass, which runs on
+/// `func.func` operations.
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::mesh::createShardingPropagationPass() {
+  auto pass = std::make_unique<ShardingPropagationPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index 8e32579e0e4c2e3..ba5343dcd7ac6c1 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -26,4 +26,18 @@ add_mlir_dialect_library(MLIRTosaDialect
   MLIRViewLikeInterface
   )
 
+add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
+  IR/ShardingInterfaceImpl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMeshDialect
+  MLIRShardingInterface
+  MLIRSupport
+  MLIRTosaDialect
+  )
+
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..dace86533c0e231
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -0,0 +1,111 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tosa-sharding-impl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::tosa;
+using namespace mlir::mesh;
+
+namespace {
+
+/// ShardingInterface external model for element-wise ops. The model uses one
+/// parallel loop per dimension of the first operand and the identity indexing
+/// map for every operand and result. Both methods return empty lists when the
+/// first operand is not a ranked tensor; `verifyShardingInterfaceImpl` rejects
+/// empty lists.
+template <typename ElemwiseOp>
+struct ElemwiseSharding
+    : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
+                                              ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    RankedTensorType type = getFirstOperandType(op);
+    if (!type)
+      return {};
+    return SmallVector<IteratorType>(type.getRank(), IteratorType::Parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    RankedTensorType type = getFirstOperandType(op);
+    if (!type)
+      return {};
+    // Every operand and result shares the same identity map.
+    int64_t num = op->getNumOperands() + op->getNumResults();
+    AffineMap identity =
+        AffineMap::getMultiDimIdentityMap(type.getRank(), op->getContext());
+    return SmallVector<AffineMap>(num, identity);
+  }
+
+private:
+  /// Type of the op's first operand if it is a ranked tensor, null otherwise.
+  static RankedTensorType getFirstOperandType(Operation *op) {
+    return llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+  }
+};
+
+// ShardingInterface external model for `tosa.matmul`:
+// A[d0, d1, d3] x B[d0, d3, d2] -> C[d0, d1, d2].
+// loop types: [parallel, parallel, parallel, reduction_sum]
+// indexing maps:
+// (d0, d1, d2, d3) -> (d0, d1, d3)
+// (d0, d1, d2, d3) -> (d0, d3, d2)
+// (d0, d1, d2, d3) -> (d0, d1, d2)
+// Both methods return empty lists unless the result is a rank-3 ranked tensor,
+// so the loop count and the map dimension count always agree.
+struct MatMulOpSharding
+    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    if (!hasRank3Result(op))
+      return {};
+    SmallVector<IteratorType> types(kNumLoops, IteratorType::Parallel);
+    types[kNumLoops - 1] = IteratorType::ReductionSum;
+    return types;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    if (!hasRank3Result(op))
+      return {};
+    MLIRContext *ctx = op->getContext();
+    SmallVector<AffineMap> maps;
+    maps.push_back(
+        AffineMap::getMultiDimMapWithTargets(kNumLoops, {0, 1, 3}, ctx));
+    maps.push_back(
+        AffineMap::getMultiDimMapWithTargets(kNumLoops, {0, 3, 2}, ctx));
+    maps.push_back(
+        AffineMap::getMultiDimMapWithTargets(kNumLoops, {0, 1, 2}, ctx));
+    return maps;
+  }
+
+private:
+  /// Three parallel loops for the result dimensions plus one reduction loop.
+  static constexpr unsigned kNumLoops = 4;
+
+  /// The maps above are only meaningful for a rank-3 ranked-tensor result.
+  static bool hasRank3Result(Operation *op) {
+    auto tensorType =
+        llvm::dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    return tensorType && tensorType.getRank() == 3;
+  }
+};
+
+/// Attaches the element-wise sharding model to `OpType` in `ctx`.
+template <typename OpType>
+static void registerElemwiseOne(MLIRContext *ctx) {
+  using Model = ElemwiseSharding<OpType>;
+  OpType::template attachInterface<Model>(*ctx);
+}
+
+/// Attaches the element-wise sharding model to each of `OpTypes`, in order.
+template <typename... OpTypes>
+static void registerElemwiseAll(MLIRContext *ctx) {
+  (..., registerElemwiseOne<OpTypes>(ctx));
+}
+
+} // namespace
+
+/// Registers the TOSA ShardingInterface external models as a TosaDialect
+/// extension of `registry`. The models are therefore attached to a context
+/// when that context loads the TOSA dialect.
+void mlir::tosa::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, TosaDialect * /*dialect*/) {
+    // Element-wise ops, attached in a fixed order.
+    registerElemwiseAll<ClampOp, SigmoidOp, TanhOp>(ctx);
+    registerElemwiseAll<AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
+                        BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp,
+                        LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
+                        LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp,
+                        SubOp>(ctx);
+    registerElemwiseAll<AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
+                        LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp>(
+        ctx);
+    registerElemwiseAll<SelectOp>(ctx);
+    registerElemwiseAll<EqualOp, GreaterOp, GreaterEqualOp>(ctx);
+
+    // Matrix multiplication.
+    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index f4e29539214b4b6..b247b8cc694eca9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -259,6 +259,19 @@ SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }
 
+/// Converts an ArrayAttr of ArrayAttrs of IntegerAttrs into nested vectors of
+/// int32_t. Each integer is narrowed with `static_cast<int32_t>`.
+SmallVector<SmallVector<int32_t>>
+mlir::getArrayOfI32Array(ArrayAttr arrayAttr) {
+  SmallVector<SmallVector<int32_t>> result;
+  for (Attribute inner : arrayAttr) {
+    SmallVector<int32_t> &values = result.emplace_back();
+    for (Attribute element : llvm::cast<ArrayAttr>(inner))
+      values.push_back(
+          static_cast<int32_t>(llvm::cast<IntegerAttr>(element).getInt()));
+  }
+  return result;
+}
+
 // TODO: do we have any common utily for this?
 static MLIRContext *getContext(OpFoldResult val) {
   assert(val && "Invalid value");
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9cdac964710ca86..db5087ce42809f8 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -235,6 +235,19 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
   return permutationMap;
 }
 
+/// Returns a map with `numDims` dimensions and no symbols whose i-th result is
+/// the dimension `targets[i]`. For example, (4, {0, 1, 3}) yields
+/// (d0, d1, d2, d3) -> (d0, d1, d3). Every target must be a valid dimension
+/// position in [0, numDims).
+AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
+                                               ArrayRef<int64_t> targets,
+                                               MLIRContext *context) {
+  SmallVector<AffineExpr, 4> results;
+  results.reserve(targets.size());
+  for (int64_t t : targets) {
+    assert(t >= 0 && static_cast<uint64_t>(t) < numDims &&
+           "target must be a valid dimension position");
+    results.push_back(getAffineDimExpr(t, context));
+  }
+  // Build the map in one step instead of re-uniquing an intermediate map for
+  // every inserted result.
+  return AffineMap::get(numDims, /*symbolCount=*/0, results, context);
+}
+
 template <typename AffineExprContainer>
 static SmallVector<AffineMap, 4>
 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index ab20f4863e11c23..b86ee432a53af04 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -316,6 +316,15 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
   return getArrayAttr(attrs);
 }
 
+/// Builds an ArrayAttr whose i-th element is the i32 ArrayAttr of `values[i]`.
+ArrayAttr
+Builder::getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values) {
+  SmallVector<Attribute, 8> attrs;
+  attrs.reserve(values.size());
+  for (const SmallVector<int32_t> &v : values)
+    attrs.push_back(getI32ArrayAttr(v));
+  return getArrayAttr(attrs);
+}
+
 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
   auto attrs = llvm::map_to_vector<8>(
       values, [](Type v) -> Attribute { return TypeAttr::get(v); });
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
new file mode 100644
index 000000000000000..4c0809dc5e58636
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+
+// Tests for the -sharding-propagation pass on TOSA element-wise ops
+// (sigmoid, tanh, abs, negate) and tosa.matmul. Each function seeds zero or
+// more mesh.shard annotations. The expected output pins the mesh.shard ops,
+// or the sharding attributes, that the pass adds around each op.
+
+// Mesh clusters referenced by the functions below; only @mesh_2d declares
+// explicit dim_sizes.
+mesh.cluster @mesh_1d(rank = 1)
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_3d(rank = 3)
+
+// With no mesh.shard annotation anywhere in the function, the pass is
+// expected to leave the IR unchanged. The sigmoid must directly follow the
+// function entry, and the return must directly follow the sigmoid.
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT: tosa.sigmoid
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: return
+  return %0 : tensor<8x16xf32>
+}
+
+// An annotation on the result of an element-wise op propagates back to its
+// operand. The pass is expected to insert an annotate_for_users mesh.shard
+// with the same [[0], [1]] sharding on the argument ahead of the sigmoid.
+// The user-written result annotation stays as-is.
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V2]]
+  return %1 : tensor<8x16xf32>
+}
+
+// An annotate_for_users sharding on the operand of an element-wise op
+// propagates forward. The pass is expected to add a mesh.shard with the same
+// [[0], [1]] sharding on the sigmoid result, which is then returned.
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V2]]
+  return %1 : tensor<8x16xf32>
+}
+
+// The only annotation is an annotate_for_users mesh.shard applied to the
+// value being returned. The pass is expected to add two mesh.shard ops: a
+// matching plain one on the sigmoid result, which the original annotation
+// then consumes, and an annotate_for_users one on the argument.
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<8x16xf32>
+}
+
+// The only annotation is a plain (not annotate_for_users) mesh.shard on the
+// function argument. The pass is expected to add a matching
+// annotate_for_users mesh.shard between it and the sigmoid, and a plain
+// mesh.shard on the sigmoid result.
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.sigmoid %[[V1]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<8x16xf32>
+}
+
+// The tanh result has two element-wise users (abs and negate), and only the
+// negate result is annotated. In the expected output the tanh, abs and
+// negate results all carry the same [[0], [1]] sharding. Each user of the
+// tanh result also gets its own annotate_for_users mesh.shard of it.
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.tanh %[[V1]]
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
+  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V8:.*]] = tosa.negate %[[V7]]
+  // CHECK-NEXT:  %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[V6]], %[[V9]]
+  return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
+// Here tosa.matmul maps (batch, m, k) x (batch, k, n) -> (batch, m, n), with
+// batch = 2, m = 16, k = 8 and n = 32. The result is annotated with batch on
+// mesh axis 0 and m on axis 1. The pass is expected to shard the lhs the
+// same way and to shard only the batch dim of the rhs.
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// The result is annotated with m on mesh axis 1 and as a partial sum over
+// axis 0. The pass is expected to map that partial sum onto the reduction
+// dim k by sharding k of both operands on axis 0 (lhs [[], [1], [0]], rhs
+// [[], [0]]).
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// The sharding is seeded only from the lhs, which is annotated for users
+// with m on mesh axis 1 and k on axis 0. The pass is expected to derive the
+// rhs sharding (k on axis 0) and a result sharding with m on axis 1 that is
+// a partial sum over axis 0.
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// Both operands are annotated for users with the reduction dim k on mesh
+// axis 0, so the k sharding is specified twice; the lhs also shards m on
+// axis 1. The pass is expected to keep both annotations unchanged and to give
+// the result m on axis 1 as a partial sum over axis 0.
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicated_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_duplicated_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+  %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %2 : tensor<2x16x32xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(a)
+// Two tosa.matmul ops with a tosa.sigmoid in between, on the 1-D mesh. The
+// argument and the returned value are sharded along their last dim on mesh
+// axis 0, and the second matmul result is annotated as a partial sum over
+// axis 0. The pass is expected to attach `mesh_cluster` and `sharding_array`
+// attributes to each op. NOTE(review): the matmul arrays appear to be indexed
+// by loop iterator (batch, m, n, k) with trailing empty entries omitted;
+// confirm against the TOSA ShardingInterface implementation.
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [0 : i32]]}
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [0 : i32]]}
+  %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [], [0 : i32]]}
+  %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
+  %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  return %5 : tensor<2x4x8xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// Two tosa.matmul ops with a tosa.sigmoid in between, on the 3-D mesh. The
+// argument is sharded along its last dim on all three mesh axes. The first
+// matmul result is sharded on axes 1 and 2 as a partial sum over axis 0, and
+// the second matmul result on axis 0 as a partial sum over axes 1 and 2. The
+// expected `sharding_array` attributes pin how the pass distributes each op
+// over those axes. NOTE(review): the matmul arrays appear to be indexed by
+// loop iterator (batch, m, n, k); confirm against the TOSA
+// ShardingInterface implementation.
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [1 : i32, 2 : i32], [0 : i32]]}
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
+  // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [1 : i32, 2 : i32]]}
+  %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [0 : i32], [1 : i32, 2 : i32]]}
+  %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
+  %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+  return %6 : tensor<2x4x8xf32>
+}



More information about the Mlir-commits mailing list