[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #71261)

Chengji Yao llvmlistbot at llvm.org
Fri Nov 3 18:02:02 PDT 2023


https://github.com/yaochengji created https://github.com/llvm/llvm-project/pull/71261

Add a pass that propagates sharding information throughout the graph.
After this pass, each of the operations' operands and results is
annotated with a mesh.shard operation.

The pass is driven by a newly added ShardingInterface, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.

>From c23087e06312b5ca4ec11ce8b782e15bcfcc6141 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Sat, 4 Nov 2023 00:54:34 +0000
Subject: [PATCH] [MLIR][Mesh] Add sharding propagation pass

---
 mlir/include/mlir/Dialect/Mesh/CMakeLists.txt |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  34 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  18 +
 .../Dialect/Mesh/Interfaces/CMakeLists.txt    |   4 +
 .../Mesh/Interfaces/ShardingInterface.h       |  68 +++
 .../Mesh/Interfaces/ShardingInterface.td      | 102 ++++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   6 +
 .../mlir/Dialect/Mesh/Transforms/Passes.h     |  39 ++
 .../mlir/Dialect/Mesh/Transforms/Passes.td    |  32 ++
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.h   |  23 +
 mlir/include/mlir/IR/AffineMap.h              |  12 +
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/include/mlir/InitAllPasses.h             |   2 +
 mlir/lib/Dialect/Mesh/CMakeLists.txt          |   2 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  32 +-
 .../Dialect/Mesh/Interfaces/CMakeLists.txt    |  15 +
 .../Mesh/Interfaces/ShardingInterface.cpp     | 524 ++++++++++++++++++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |  19 +
 .../Mesh/Transforms/ShardingPropagation.cpp   | 210 +++++++
 mlir/lib/Dialect/Tosa/CMakeLists.txt          |  14 +
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 111 ++++
 mlir/lib/IR/AffineMap.cpp                     |  16 +-
 .../Dialect/Mesh/sharding-propagation.mlir    | 188 +++++++
 23 files changed, 1470 insertions(+), 5 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/Mesh/sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..a91ef569347bff1 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+  I32EnumAttrCase<"Parallel", 1, "parallel">,
+  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+  I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+    let genSpecializedAttr = 0;
+    let cppNamespace = "::mlir::mesh";
+}
+
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
        $partial_axes^ `]`)? `>`
   }];
 
+  let builders = [
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
+                     "ArrayRef<int32_t>": $partial_axes,
+                     "mesh::Partial": $partial_type), [{
+      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
+                  split_axes, [&](ArrayRef<int32_t> array) {
+          return DenseI32ArrayAttr::get($_ctxt, array);
+      });
+      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+                   partial_type);
+    }]>,
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+    }]>
+  ];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..05eba66a89949b6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,22 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+  while (!array.empty() && array.back().empty())
+    array.pop_back();
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..d860628cf371aa9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,68 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+  // An array of int array. The sub-array at the i-th position signifies the
+  // mesh axes the i-th loop will be sharded on.
+  ShardingArray shardingArray;
+  SymbolRefAttr cluster;
+  // `empty` being true indicates that no sharding information can be inferred
+  // at present. Note that it is different from the case where an operation is
+  // not sharded.
+  bool empty = false;
+  ShardingOption() = default;
+  ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+      : shardingArray(std::move(shardingArray)), cluster(cluster) {}
+};
+
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// result and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpResult result);
+
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// operand and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand);
+
+namespace detail {
+
+FailureOr<ShardingOption>
+defaultGetShardingOption(Operation *op,
+                         ArrayRef<MeshShardingAttr> operandShardings,
+                         ArrayRef<MeshShardingAttr> resultShardings);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+                              const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..21b6c8d4f599a8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,102 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+    let description = [{
+        Interface for allowing operations to expose information needed to
+        shard them.
+    }];
+    let cppNamespace = "::mlir::mesh";
+
+    let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of iterator types that describe the number of loops.
+          The iterator types determine how the operation traverses its input and
+          output tensors.
+
+          Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
+          types are parallel, parallel, reduction-sum. This indicates that M and
+          N are traversed in parallel, while the K dimension is used for
+          reduction.
+
+          Example 2: A softmax op's loop iterator types are parallel and
+          invalid. The second dimension is considered as invalid because it is
+          neither parallel nor any kind of reduction. 
+        }],
+        /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*methodName=*/"getLoopIteratorTypes",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the indexing maps attribute within the current operation.
+          Indexing maps determine how indices in the iteration space map to
+          tensor indices. They are specified using `affine_map` in MLIR, which
+          provides an affine transformation of indices.
+        }],
+        /*retTy=*/"SmallVector<AffineMap>",
+        /*methodName=*/"getIndexingMaps",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Given that certain operands or results of the operation may have
+          sharding annotations, this method leverages this information to deduce
+          how the operation should be sharded.
+        }],
+        /*retTy=*/"FailureOr<ShardingOption>",
+        /*methodName=*/"getShardingOption",
+        /*args=*/(ins
+          "ArrayRef<MeshShardingAttr>": $operandShardings,
+          "ArrayRef<MeshShardingAttr>": $resultShardings
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultGetShardingOption(
+            $_op.getOperation(), operandShardings, resultShardings);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Based on a given ShardingOption, this method adds `mesh.shard`
+          operations for the operands and results that previously lacked
+          sharding annotations.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"addShardingAnnotations",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "const ShardingOption &":$shardingOption
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultAddShardingAnnotations(
+            $_op.getOperation(), b, shardingOption);
+        }]
+      >
+    ];
+
+    let extraClassDeclaration = [{
+      LogicalResult verifyShardingInterfaceImpl();
+
+      void printLoopTypesAndIndexingMaps(raw_ostream &os);
+    }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..83399d10beaae48
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,39 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..c09cf3e710d4278
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,32 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+  let summary = "sharding propagation";
+  let description = [{
+    Propagates sharding information throughout the graph. After this pass, each
+    of the operations' operands and results is annotated with a `mesh.shard`
+    operation, and the operations themselves are annotated with sharding option
+    attributes.
+  }];
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+/// Registers the ShardingInterface external models with `registry`.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..f691a3daf8889c5 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,6 +104,18 @@ class AffineMap {
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns an affine map with `numDims` input dimensions and results
+  /// specified by `targets`.
+  ///
+  /// Examples:
+  /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+  /// * getMultiDimMapWithTargets(3, [2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d2, d1)>
+  static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+                                             ArrayRef<unsigned> targets,
+                                             MLIRContext *context);
+
   /// Returns a vector of AffineMaps; each with as many results as
   /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
   /// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 621110d130818d3..395d899f9ad84b0 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+  tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   vector::registerSubsetOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 80894094484b999..f22980036ffcfa1 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -74,6 +75,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
   registerSCFPasses();
   registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+  return !(iType == IteratorType::Parallel || iType == IteratorType::Invalid);
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+  // Parallel and invalid iterators are not reductions and match no partial
+  // type; each reduction kind matches exactly its own partial type.
+  if (!isReductionLoop(iType))
+    return false;
+  return getPartialTypeFromReduction(iType) == partial;
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return Partial::Generic;
+  case IteratorType::ReductionSum:
+    return Partial::Sum;
+  case IteratorType::ReductionMax:
+    return Partial::Max;
+  case IteratorType::ReductionMin:
+    return Partial::Min;
+  default: // Parallel and Invalid have no partial type; callers must not ask.
+    llvm_unreachable("No corresponding partial type can be found");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.cluster op
 //===----------------------------------------------------------------------===//
@@ -95,7 +126,6 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   }
   if (failed(checkMeshAxis(partialAxes)))
     return failure();
-
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..1010756f1fe279a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_library(MLIRShardingInterface
+  ShardingInterface.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRShardingInterfaceIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMeshDialect
+  MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
new file mode 100644
index 000000000000000..c2e1d1c726816a5
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -0,0 +1,524 @@
+//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// common util functions
+//===----------------------------------------------------------------------===//
+
+// Accepts sums of terms that are each a dim or a dim times a constant, marking
+// every dim in `seenIds`; fails on any other form or on a repeated dim.
+static LogicalResult
+checkOperandAffineExprRecursively(AffineExpr expr,
+                                  SmallVectorImpl<bool> &seenIds) {
+  // Marks the dim behind `dimExpr` as seen; fails if it is out of range or
+  // was already seen.
+  auto markDim = [&](AffineExpr dimExpr) -> LogicalResult {
+    unsigned position = dimExpr.cast<AffineDimExpr>().getPosition();
+    if ((size_t)position >= seenIds.size() || seenIds[position])
+      return failure();
+    seenIds[position] = true;
+    return success();
+  };
+
+  switch (expr.getKind()) {
+  case AffineExprKind::Add: {
+    auto sum = expr.cast<AffineBinaryOpExpr>();
+    // Both summands must be valid; the left one is checked first.
+    if (failed(checkOperandAffineExprRecursively(sum.getLHS(), seenIds)))
+      return failure();
+    return checkOperandAffineExprRecursively(sum.getRHS(), seenIds);
+  }
+  case AffineExprKind::Mul: {
+    // Only `dim * constant` (in either operand order) is supported.
+    auto product = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = product.getLHS();
+    AffineExpr rhs = product.getRHS();
+    if (lhs.getKind() == AffineExprKind::DimId &&
+        rhs.getKind() == AffineExprKind::Constant)
+      return markDim(lhs);
+    if (rhs.getKind() == AffineExprKind::DimId &&
+        lhs.getKind() == AffineExprKind::Constant)
+      return markDim(rhs);
+    return failure();
+  }
+  case AffineExprKind::DimId: {
+    return markDim(expr);
+  }
+  // Every other expression kind is rejected.
+  default:
+    return failure();
+  }
+}
+
+// Collects the dim positions used by `expr`; fails if `expr` is rejected.
+static FailureOr<llvm::SmallSet<unsigned, 2>>
+checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
+  SmallVector<bool> dimSeen(numDims, false);
+  if (failed(checkOperandAffineExprRecursively(expr, dimSeen)))
+    return failure();
+
+  llvm::SmallSet<unsigned, 2> usedDims;
+  for (unsigned dim = 0; dim < numDims; ++dim)
+    if (dimSeen[dim])
+      usedDims.insert(dim);
+  return usedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// mesh::getMeshShardingAttr
+//===----------------------------------------------------------------------===//
+
+FailureOr<std::pair<bool, MeshShardingAttr>>
+mesh::getMeshShardingAttr(OpResult result) {
+  Value val = result.cast<Value>();
+  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      return false;
+    return !shardOp.getAnnotateForUsers();
+  });
+
+  if (anyShardedForDef) {
+    // A `mesh.shard` user without `annotate_for_users` must be the value's
+    // only use; any additional use makes this query fail.
+    if (!val.hasOneUse())
+      return failure();
+    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
+    return std::make_pair(false, shardOp.getShard());
+  }
+
+  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
+    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    if (!shardOp)
+      return false;
+    return shardOp.getAnnotateForUsers();
+  });
+  if (anyShardedForUsers) {
+    SmallVector<ShardOp> shardOps;
+    for (Operation *user : val.getUsers()) {
+      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
+      if (shardOp)
+        shardOps.push_back(shardOp);
+    }
+    MeshShardingAttr shardForDef = shardOps[0].getShard();
+    for (size_t i = 1; i < shardOps.size(); ++i) {
+      // TODO: Deduce a reasonable mesh sharding attr for the def when the
+      // shard ops disagree; for now this is only checked by an assertion.
+      assert(shardOps[i].getShard() == shardForDef &&
+             "only support all shard ops have the same mesh sharding attr");
+    }
+    return std::make_pair(true, shardForDef);
+  }
+  return failure();
+}
+
+FailureOr<std::pair<bool, MeshShardingAttr>>
+mesh::getMeshShardingAttr(OpOperand &opOperand) {
+  // Only an operand produced directly by a `mesh.shard` op carries a sharding.
+  auto producer = opOperand.get().getDefiningOp<ShardOp>();
+  if (!producer)
+    return failure();
+  return std::make_pair(producer.getAnnotateForUsers(), producer.getShard());
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
+// Verifies operand/result types, loop iterator types and indexing maps.
+LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+  Operation *op = getOperation();
+
+  // Every operand and result must be a ranked tensor.
+  for (Type type : op->getOperandTypes())
+    if (!llvm::isa<RankedTensorType>(type))
+      return failure();
+  for (Type type : op->getResultTypes())
+    if (!llvm::isa<RankedTensorType>(type))
+      return failure();
+
+  // The op must expose at least one loop iterator type.
+  SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+  if (loopTypes.empty())
+    return failure();
+
+  // There must be exactly one indexing map per operand and per result,
+  // ordered operands first.
+  SmallVector<AffineMap> maps = getIndexingMaps();
+  if (maps.empty())
+    return failure();
+  unsigned numOperands = op->getNumOperands();
+  unsigned numResults = op->getNumResults();
+  if (numOperands + numResults != maps.size())
+    return failure();
+
+  // Result maps must be projected permutations. Result types need no second
+  // check here: they were all verified to be ranked tensors above.
+  for (OpResult result : op->getResults()) {
+    AffineMap map = maps[numOperands + result.getResultNumber()];
+    if (!map.isProjectedPermutation())
+      return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::printLoopTypesAndIndexingMaps
+//===----------------------------------------------------------------------===//
+
+void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+  // Header naming the op that is being described.
+  os << "print loop types and indexing maps for: \n";
+  getOperation()->print(os);
+  os << "\n"
+     << "loop types: [";
+  for (IteratorType loopType : getLoopIteratorTypes())
+    os << stringifyEnum(loopType) << " ";
+  os << "]\n"
+     << "indexing maps: \n";
+  for (AffineMap indexingMap : getIndexingMaps())
+    os << indexingMap << "\n";
+  os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultGetShardingOption
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Records in `shardingOption` that the loop `loopIdx` is sharded along
+// `meshAxes` of `cluster`.
+//
+// Fails, leaving `shardingOption` untouched, when
+//   * a different (non-null) cluster has already been recorded,
+//   * the loop already carries a non-empty list of axes that differs from
+//     `meshAxes`, or
+//   * one of `meshAxes` is already used by another loop.
+// A null `cluster` neither conflicts with nor overwrites the recorded cluster.
+// `op` is currently unused.
+static LogicalResult fillShardingOption(Operation *op,
+                                        ShardingOption &shardingOption,
+                                        SymbolRefAttr cluster,
+                                        ArrayRef<int32_t> meshAxes,
+                                        unsigned loopIdx) {
+  const bool clusterConflict = shardingOption.cluster && cluster &&
+                               shardingOption.cluster != cluster;
+  const bool axesConflict =
+      !shardingOption.shardingArray[loopIdx].empty() &&
+      shardingOption.shardingArray[loopIdx] != meshAxes;
+  if (clusterConflict || axesConflict) {
+    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
+                      << loopIdx << "\n");
+    return failure();
+  }
+
+  // A mesh axis may shard at most one loop: reject `meshAxes` if any of them
+  // is already assigned to a loop other than `loopIdx`.
+  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
+    if (i == loopIdx)
+      continue;
+
+    for (int32_t axis : meshAxes) {
+      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
+        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
+                          << axis << " duplicate\n");
+        return failure();
+      }
+    }
+  }
+
+  // No conflict: commit. A non-empty entry is known to equal `meshAxes`
+  // already, so only an empty one needs to be filled.
+  if (cluster)
+    shardingOption.cluster = cluster;
+  if (shardingOption.shardingArray[loopIdx].empty())
+    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
+                                                 meshAxes.end());
+  return success();
+}
+
+} // namespace
+
+// Default way of deriving a `ShardingOption` (mesh axes per loop, plus the
+// cluster) for `op` from the shardings of its results and operands. Null
+// entries in `operandShardings` / `resultShardings` mean "no annotation".
+//
+// Result annotations are processed first, then operand annotations, and
+// finally the partial axes requested by a result are assigned to a reduction
+// loop. Fails if the interface implementation is invalid or if the
+// annotations cannot be reconciled.
+FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
+    Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings) {
+  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+  ShardingOption shardingOption;
+
+  if (failed(shardingOp.verifyShardingInterfaceImpl()))
+    return op->emitOpError() << "invalid sharding interface implementation";
+
+  SmallVector<IteratorType> iteratorTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<AffineMap> indexingMaps = shardingOp.getIndexingMaps();
+  const unsigned operandCount = op->getNumOperands();
+  shardingOption.shardingArray.resize(iteratorTypes.size());
+
+  // Partial axes requested by a result and their kind. `resultPartialType` is
+  // only read after it has been assigned together with `resultPartialAxes`.
+  llvm::SmallVector<int32_t> resultPartialAxes;
+  Partial resultPartialType;
+  // Loops that an annotation processed so far refers to unambiguously.
+  llvm::SmallSet<unsigned, 4> pinnedLoops;
+  bool sawAnySharding = false;
+
+  // 1. Fill sharding option based on op results
+  for (auto indexedSharding : llvm::enumerate(resultShardings)) {
+    MeshShardingAttr sharding = indexedSharding.value();
+    if (!sharding)
+      continue;
+    sawAnySharding = true;
+    AffineMap resultMap = indexingMaps[operandCount + indexedSharding.index()];
+
+    // Split axes: a result map is a projected permutation, so each result
+    // dimension corresponds to exactly one loop, which inherits the split
+    // axes of that dimension.
+    for (auto [expr, axesAttr] :
+         llvm::zip(resultMap.getResults(), sharding.getSplitAxes())) {
+      unsigned loopIdx = expr.cast<AffineDimExpr>().getPosition();
+      pinnedLoops.insert(loopIdx);
+      if (failed(fillShardingOption(op, shardingOption, sharding.getCluster(),
+                                    axesAttr.asArrayRef(), loopIdx)))
+        return failure();
+    }
+
+    // Partial axes: the reduction loop(s) they belong to cannot be decided
+    // yet because there could be several; remember them for step 3.
+    ArrayRef<int32_t> partialAxes = sharding.getPartialAxes();
+    if (partialAxes.empty())
+      continue;
+    if (!resultPartialAxes.empty())
+      return op->emitOpError() << "at most one result with partial axes is "
+                                  "supported at present";
+    resultPartialType = sharding.getPartialType();
+    resultPartialAxes.append(partialAxes.begin(), partialAxes.end());
+    // With partial axes present, every reduction loop counts as pinned.
+    for (size_t loopIdx = 0, e = iteratorTypes.size(); loopIdx < e; ++loopIdx)
+      if (isReductionLoop(iteratorTypes[loopIdx]))
+        pinnedLoops.insert(loopIdx);
+  }
+
+  // 2. Fill sharding option based on operands
+  for (auto indexedSharding : llvm::enumerate(operandShardings)) {
+    MeshShardingAttr sharding = indexedSharding.value();
+    if (!sharding)
+      continue;
+    sawAnySharding = true;
+    AffineMap operandMap = indexingMaps[indexedSharding.index()];
+    unsigned numDims = operandMap.getNumDims();
+
+    // Only the split axes matter here. Partial axes of an operand only affect
+    // the op that defines it.
+    //
+    // TODO: Change to process the operands with single loop index first and
+    // then the operands with multiple loop indices.
+    for (auto [expr, axesAttr] :
+         llvm::zip(operandMap.getResults(), sharding.getSplitAxes())) {
+      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+          checkOperandAffineExpr(expr, numDims);
+      if (failed(loopIndices))
+        return op->emitOpError()
+               << "operand's affine expression is restricted to const_i * "
+                  "dim_i + const_j + dim_j + ...";
+
+      // The dimension does not depend on any loop: nothing to record.
+      if (loopIndices->empty())
+        continue;
+
+      // Exactly one loop: it inherits the split axes of this dimension.
+      if (loopIndices->size() == 1) {
+        unsigned loopIdx = *loopIndices->begin();
+        pinnedLoops.insert(loopIdx);
+        if (failed(fillShardingOption(op, shardingOption,
+                                      sharding.getCluster(),
+                                      axesAttr.asArrayRef(), loopIdx)))
+          return failure();
+        continue;
+      }
+
+      // Several loops feed this dimension, so it is unclear which of them is
+      // responsible for the sharding. Require that at least one of them has
+      // already been pinned by another annotation.
+      bool anyPinned = false;
+      for (unsigned loopIdx : *loopIndices) {
+        if (pinnedLoops.contains(loopIdx)) {
+          anyPinned = true;
+          break;
+        }
+      }
+      if (!anyPinned)
+        return op->emitOpError()
+               << "the operand " << indexedSharding.index()
+               << " has multiple loop indices in a dimension, but none of "
+                  "them could be found in the exactly specified annotation "
+                  "of op results or operands.";
+    }
+  }
+
+  // 3. Finalize sharding option
+  if (!resultPartialAxes.empty()) {
+    // Has some reduction loop already been given mesh axes by steps 1/2?
+    bool reductionLoopSharded = false;
+    for (size_t idx = 0, e = shardingOption.shardingArray.size(); idx < e;
+         ++idx) {
+      if (isReductionLoop(iteratorTypes[idx]) &&
+          !shardingOption.shardingArray[idx].empty()) {
+        reductionLoopSharded = true;
+        break;
+      }
+    }
+
+    if (!reductionLoopSharded) {
+      // Give the partial axes to the first reduction loop whose kind matches
+      // the partial type of the result.
+      const size_t numLoops = iteratorTypes.size();
+      size_t matchIdx = 0;
+      while (matchIdx < numLoops &&
+             !(isReductionLoop(iteratorTypes[matchIdx]) &&
+               areReductionAndPartialMatch(iteratorTypes[matchIdx],
+                                           resultPartialType)))
+        ++matchIdx;
+      if (matchIdx == numLoops)
+        return op->emitOpError() << "no matched reduction loop found for the "
+                                    "result's partial type";
+      // NOTE(review): a conflict reported by fillShardingOption is ignored
+      // here on purpose in the existing logic - confirm this is intended.
+      std::ignore = fillShardingOption(op, shardingOption, nullptr,
+                                       resultPartialAxes, matchIdx);
+    }
+  }
+
+  removeTrailingEmptySubArray(shardingOption.shardingArray);
+  // Without any annotation on results or operands there is nothing to
+  // propagate; flag the option as empty.
+  if (!sawAnySharding)
+    shardingOption.empty = true;
+  return shardingOption;
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultAddShardingAnnotations
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Adds a `mesh.shard` op right after `result` and redirects all other uses of
+// `result` to it. The sharding is derived from `shardingOption`:
+//   * the split axes of result dimension i are the mesh axes of the loop that
+//     `map` assigns to dimension i;
+//   * the partial axes are the mesh axes of all reduction loops in
+//     `loopTypes`.
+// Nothing is inserted when `getMeshShardingAttr(result)` succeeds and the
+// flag it returns is false (presumably: the result already has a `mesh.shard`
+// that is not `annotate_for_users` - TODO confirm against the helper).
+// Always returns success.
+static LogicalResult addShardOp(OpBuilder &b, OpResult result,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<IteratorType> loopTypes) {
+  FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
+      getMeshShardingAttr(result);
+  if (succeeded(maybeSharding) && !maybeSharding->first)
+    return success();
+
+  auto resultType = result.getType().cast<RankedTensorType>();
+  SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
+  SmallVector<int32_t> partialAxes;
+
+  // process the split axes
+  for (auto it : llvm::enumerate(map.getResults())) {
+    AffineExpr expr = it.value();
+    // `expr` must be an `AffineDimExpr` because `map` is verified by
+    // isProjectedPermutation
+    auto dim = expr.cast<AffineDimExpr>();
+    unsigned loopIdx = dim.getPosition();
+    // Trailing loops without mesh axes may have been trimmed from the
+    // sharding array, hence the bounds check.
+    if (loopIdx < shardingOption.shardingArray.size())
+      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
+  }
+
+  // process the partial axes
+  // `partialType` is only meaningful when `partialAxes` ends up non-empty. It
+  // still has to hold a well-defined value because it is handed to
+  // `MeshShardingAttr::get` unconditionally, even for ops without any
+  // reduction loop.
+  Partial partialType = Partial::Sum;
+  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
+    IteratorType iType = std::get<0>(it);
+    if (!isReductionLoop(iType))
+      continue;
+    Partial curPartialType = getPartialTypeFromReduction(iType);
+    assert((partialAxes.empty() || partialType == curPartialType) &&
+           "Only one reduction type is supported");
+    partialType = curPartialType;
+    const SmallVector<int32_t> &axis = std::get<1>(it);
+    partialAxes.append(axis);
+  }
+
+  removeTrailingEmptySubArray(splitAxes);
+  MeshShardingAttr shardAttr =
+      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
+                            partialAxes, partialType);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfterValue(result);
+  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
+                                   shardAttr, /*annotate_for_users*/ false);
+  // The new shard op itself must keep consuming the original result.
+  result.replaceAllUsesExcept(shardOp, shardOp);
+  return success();
+}
+
+// Wraps the value consumed through `opOperand` in a `mesh.shard` op that is
+// created right before the owning operation, and makes the operand use it.
+// The split axes of operand dimension i are the mesh axes of the (single)
+// sharded loop that `map` uses for dimension i. `loopTypes` is not consulted
+// for operands.
+//
+// Nothing is inserted when `getMeshShardingAttr(opOperand)` succeeds and the
+// flag it returns is true. Fails if an expression of `map` is not supported
+// or if one operand dimension depends on more than one sharded loop.
+static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<IteratorType> loopTypes) {
+  auto existingSharding = getMeshShardingAttr(opOperand);
+  if (succeeded(existingSharding) && existingSharding->first)
+    return success();
+
+  Value operand = opOperand.get();
+  auto operandType = operand.getType().cast<RankedTensorType>();
+  SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+  const unsigned numDims = map.getNumDims();
+  const size_t numShardedEntries = shardingOption.shardingArray.size();
+
+  for (auto indexedExpr : llvm::enumerate(map.getResults())) {
+    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+        checkOperandAffineExpr(indexedExpr.value(), numDims);
+    if (failed(loopIndices))
+      return failure();
+
+    // Look for the loop that carries mesh axes among the loops feeding this
+    // dimension; at most one sharded loop index is accepted.
+    const SmallVector<int32_t> *shardedLoopAxes = nullptr;
+    for (unsigned loopIdx : *loopIndices) {
+      if ((size_t)loopIdx >= numShardedEntries)
+        continue;
+      const SmallVector<int32_t> &loopAxes =
+          shardingOption.shardingArray[loopIdx];
+      if (loopAxes.empty())
+        continue;
+      if (shardedLoopAxes)
+        return failure();
+      shardedLoopAxes = &loopAxes;
+    }
+    if (shardedLoopAxes)
+      splitAxes[indexedExpr.index()].append(*shardedLoopAxes);
+  }
+
+  removeTrailingEmptySubArray(splitAxes);
+  MeshShardingAttr shardAttr =
+      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
+
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(opOperand.getOwner());
+  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
+                                   shardAttr, /*annotate_for_users*/ true);
+  opOperand.set(shardOp);
+
+  return success();
+}
+
+} // namespace
+
+// Default way of materializing `shardingOption` on `op`: a `mesh.shard` op is
+// requested for every result first and for every operand afterwards, each
+// with the indexing map that belongs to it (operand maps come first in the
+// list of maps, result maps follow). Stops at the first failure.
+LogicalResult mesh::detail::defaultAddShardingAnnotations(
+    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
+  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+  SmallVector<IteratorType> iteratorTypes = shardingOp.getLoopIteratorTypes();
+  SmallVector<AffineMap> indexingMaps = shardingOp.getIndexingMaps();
+  // Index of the first result map.
+  const unsigned resultMapOffset = op->getNumOperands();
+
+  // 1. add mesh.shard ops for all op results
+  for (OpResult result : op->getResults()) {
+    AffineMap resultMap =
+        indexingMaps[resultMapOffset + result.getResultNumber()];
+    if (failed(
+            addShardOp(b, result, shardingOption, resultMap, iteratorTypes)))
+      return failure();
+  }
+
+  // 2. add mesh.shard ops for all operands
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    AffineMap operandMap = indexingMaps[opOperand.getOperandNumber()];
+    if (failed(addShardOp(b, opOperand, shardingOption, operandMap,
+                          iteratorTypes)))
+      return failure();
+  }
+
+  return success();
+}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..bcf45c4ea276080
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRMeshTransforms
+  ShardingPropagation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRMeshPassIncGen
+  MLIRShardingInterface
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRMeshDialect
+  MLIRPass
+  MLIRShardingInterface
+  MLIRSupport
+  MLIRTosaShardingInterfaceImpl
+)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
new file mode 100644
index 000000000000000..3aed912fb43c63e
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -0,0 +1,210 @@
+//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+#include <vector>
+
+namespace mlir {
+namespace mesh {
+#define GEN_PASS_DEF_SHARDINGPROPAGATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+} // namespace mesh
+} // namespace mlir
+
+#define DEBUG_TYPE "sharding-propagation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+// Enumerates every combination of sharding attributes worth trying, with the
+// most specific combinations first. Position i of a combination is
+//   * `mustShardings[i]` if that is set, otherwise
+//   * `optionalShardings[i]` and then, as a second alternative, null if the
+//     optional sharding is set, otherwise
+//   * null.
+// For example, mustShardings = [shard0, None] and optionalShardings =
+// [None, shard1] yield [[shard0, shard1], [shard0, None]]. Both arrays are
+// expected to have the same length.
+static SmallVector<SmallVector<MeshShardingAttr>>
+getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
+                                ArrayRef<MeshShardingAttr> optionalShardings) {
+  SmallVector<SmallVector<MeshShardingAttr>> combinations;
+  // Prefix of the combination currently being built by the recursion below.
+  SmallVector<MeshShardingAttr> prefix;
+
+  std::function<void(size_t)> extendFrom = [&](size_t pos) {
+    // A complete combination: store a copy of it.
+    if (pos == mustShardings.size()) {
+      combinations.push_back(SmallVector<MeshShardingAttr>(prefix));
+      return;
+    }
+
+    // Puts `choice` at position `pos` and enumerates all completions.
+    auto recurseWith = [&](MeshShardingAttr choice) {
+      prefix.push_back(choice);
+      extendFrom(pos + 1);
+      prefix.pop_back();
+    };
+
+    // A mandatory sharding leaves no alternative.
+    if (MeshShardingAttr mandatory = mustShardings[pos]) {
+      recurseWith(mandatory);
+      return;
+    }
+
+    // Try the optional sharding, if any, before falling back to "none".
+    if (MeshShardingAttr optional = optionalShardings[pos])
+      recurseWith(optional);
+    recurseWith(MeshShardingAttr());
+  };
+
+  extendFrom(0);
+  return combinations;
+}
+
+// Propagates sharding through a single operation. Terminators and
+// `mesh.shard` ops are skipped; any other op must implement the
+// ShardingInterface. The sharding option of the op is inferred from the
+// annotations on its operands and results via `getShardingOption`. If that
+// option is not empty, `addShardingAnnotations` is asked to add `mesh.shard`
+// ops for the operands and results that still lack an annotation.
+LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
+    return success();
+
+  auto shardingOp = llvm::dyn_cast<ShardingInterface>(op);
+  if (!shardingOp)
+    return op->emitOpError() << "sharding interface is not implemented.";
+
+  // Collect MeshShardingAttr from results. An annotation whose flag (the
+  // first member of the returned pair) is false is binding for a result; all
+  // others may be dropped on conflict.
+  const unsigned numResults = op->getNumResults();
+  SmallVector<MeshShardingAttr> allowConflictsResultShardings(numResults);
+  SmallVector<MeshShardingAttr> resultMustShardings(numResults);
+  for (OpResult result : op->getResults()) {
+    FailureOr<std::pair<bool, MeshShardingAttr>> annotation =
+        getMeshShardingAttr(result);
+    if (failed(annotation))
+      continue;
+    SmallVector<MeshShardingAttr> &target = annotation->first
+                                                ? allowConflictsResultShardings
+                                                : resultMustShardings;
+    target[result.getResultNumber()] = annotation->second;
+  }
+
+  // Collect MeshShardingAttr from operands. Here it is the other way round:
+  // an annotation whose flag is true is the binding one.
+  const unsigned numOperands = op->getNumOperands();
+  SmallVector<MeshShardingAttr> allowConflictsOperandShardings(numOperands);
+  SmallVector<MeshShardingAttr> operandMustShardings(numOperands);
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    FailureOr<std::pair<bool, MeshShardingAttr>> annotation =
+        getMeshShardingAttr(opOperand);
+    if (failed(annotation))
+      continue;
+    SmallVector<MeshShardingAttr> &target =
+        annotation->first ? operandMustShardings
+                          : allowConflictsOperandShardings;
+    target[opOperand.getOperandNumber()] = annotation->second;
+  }
+
+  // Try the candidate combinations, most specific first, and keep the first
+  // one for which the op can produce a sharding option.
+  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
+      getOrderedPossibleShardingAttrs(operandMustShardings,
+                                      allowConflictsOperandShardings);
+  SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
+      getOrderedPossibleShardingAttrs(resultMustShardings,
+                                      allowConflictsResultShardings);
+  auto findFirstShardingOption = [&]() -> FailureOr<ShardingOption> {
+    for (ArrayRef<MeshShardingAttr> resultShardings :
+         possibleResultShardingAttrs) {
+      for (ArrayRef<MeshShardingAttr> operandShardings :
+           possibleOperandShardingAttrs) {
+        FailureOr<ShardingOption> candidate =
+            shardingOp.getShardingOption(operandShardings, resultShardings);
+        if (succeeded(candidate))
+          return candidate;
+      }
+    }
+    return failure();
+  };
+  FailureOr<ShardingOption> finalShardingOption = findFirstShardingOption();
+  if (failed(finalShardingOption))
+    return op->emitOpError() << "fail to get sharding option.";
+
+  // sharding info is empty, return immediately
+  if (finalShardingOption->empty)
+    return success();
+
+  if (failed(shardingOp.addShardingAnnotations(builder, *finalShardingOption)))
+    return op->emitOpError() << "fail to set sharding annotations.";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+// Pass that propagates sharding annotations through the single block of a
+// function: every op is visited once in reverse order and then once more in
+// the original order. The pass fails if the function body does not consist
+// of exactly one block or if visiting an op fails.
+struct ShardingPropagation
+    : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    MLIRContext *ctx = funcOp.getContext();
+    Region &region = funcOp.getBody();
+    OpBuilder builder(ctx);
+    if (!region.hasOneBlock()) {
+      funcOp.emitOpError() << "only one block is supported!";
+      // Bail out: `region.front()` below must not be reached for a region
+      // that is empty (e.g. a function declaration) or has several blocks.
+      return signalPassFailure();
+    }
+    Block &block = region.front();
+
+    LLVM_DEBUG(
+        DBGS() << "print all the ops' iterator types and indexing maps in the "
+                  "block.\n";
+        for (Operation &op
+             : block.getOperations()) {
+          if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
+            shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
+        });
+
+    // `visitOp` may insert `mesh.shard` ops around the op it visits; the
+    // early-inc ranges fetch the next op before the current one is visited.
+
+    // 1. propagate in reversed order
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
+
+    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+                      << funcOp << "\n");
+
+    // 2. propagate in original order
+    for (Operation &op : llvm::make_early_inc_range(block))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index 8e32579e0e4c2e3..ba5343dcd7ac6c1 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -26,4 +26,18 @@ add_mlir_dialect_library(MLIRTosaDialect
   MLIRViewLikeInterface
   )
 
+add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
+  IR/ShardingInterfaceImpl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMeshDialect
+  MLIRShardingInterface
+  MLIRSupport
+  MLIRTosaDialect
+  )
+
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..dace86533c0e231
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -0,0 +1,111 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tosa-sharding-impl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::tosa;
+using namespace mlir::mesh;
+
+namespace {
+
+// External ShardingInterface model shared by the element-wise ops: one
+// parallel loop per dimension of the first operand, and an identity indexing
+// map for every operand and every result. Both methods return an empty list
+// when the first operand is not a ranked tensor.
+template <typename ElemwiseOp>
+struct ElemwiseSharding
+    : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
+                                              ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto firstOperandType =
+        llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+    if (!firstOperandType)
+      return {};
+    return SmallVector<IteratorType>(firstOperandType.getRank(),
+                                     IteratorType::Parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    auto firstOperandType =
+        llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+    if (!firstOperandType)
+      return {};
+    // The same identity map serves all operands and results.
+    AffineMap identity = AffineMap::getMultiDimIdentityMap(
+        firstOperandType.getRank(), op->getContext());
+    int64_t mapCount = op->getNumOperands() + op->getNumResults();
+    return SmallVector<AffineMap>(mapCount, identity);
+  }
+};
+
+// External ShardingInterface model for tosa.matmul.
+//
+// loop types: [parallel, parallel, parallel, reduction_sum]
+// indexing maps:
+// (d0, d1, d2, d3) -> (d0, d1, d3)
+// (d0, d1, d2, d3) -> (d0, d3, d2)
+// (d0, d1, d2, d3) -> (d0, d1, d2)
+//
+// Both methods return an empty list when the result is not a ranked tensor.
+struct MatMulOpSharding
+    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto resultType =
+        llvm::dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!resultType)
+      return {};
+
+    // One parallel loop per result dimension, then the reduction loop.
+    SmallVector<IteratorType> loopTypes(resultType.getRank(),
+                                        IteratorType::Parallel);
+    loopTypes.push_back(IteratorType::ReductionSum);
+    return loopTypes;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    auto resultType =
+        llvm::dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!resultType)
+      return {};
+
+    // Loop dimensions selected by the first operand, the second operand and
+    // the result, in that order.
+    static constexpr unsigned kNumLoops = 4;
+    static constexpr unsigned kTargets[3][3] = {
+        {0, 1, 3}, {0, 3, 2}, {0, 1, 2}};
+
+    MLIRContext *ctx = op->getContext();
+    SmallVector<AffineMap> maps;
+    for (const auto &targets : kTargets)
+      maps.push_back(
+          AffineMap::getMultiDimMapWithTargets(kNumLoops, targets, ctx));
+    return maps;
+  }
+};
+
+/// Attaches the element-wise sharding model to the single op `OpType`.
+template <typename OpType>
+static void registerElemwiseOne(MLIRContext *ctx) {
+  using Model = ElemwiseSharding<OpType>;
+  OpType::template attachInterface<Model>(*ctx);
+}
+
+/// Variadic helper: attaches the element-wise sharding model to each of
+/// `OpTypes`, in the order in which they are listed.
+template <typename... OpTypes>
+static void registerElemwiseAll(MLIRContext *ctx) {
+  // Comma fold: one registerElemwiseOne call per op type.
+  ((void)registerElemwiseOne<OpTypes>(ctx), ...);
+}
+
+} // namespace
+
+// Adds an extension for the TOSA dialect to `registry` which attaches the
+// external ShardingInterface models: the element-wise model to all the ops
+// listed below and the matmul model to tosa.matmul.
+void mlir::tosa::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  // The unary plus turns the captureless lambda into a plain function
+  // pointer.
+  auto attachModels = +[](MLIRContext *ctx, TosaDialect * /*dialect*/) {
+    registerElemwiseAll<
+        ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
+        BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
+        LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
+        MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
+        LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+        GreaterOp, GreaterEqualOp>(ctx);
+
+    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+  };
+  registry.addExtension(attachModels);
+}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..cdcd71cdd7cd151 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -227,15 +227,23 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
                                        MLIRContext *context) {
   assert(!permutation.empty() &&
          "Cannot create permutation map from empty permutation vector");
-  SmallVector<AffineExpr, 4> affExprs;
-  for (auto index : permutation)
-    affExprs.push_back(getAffineDimExpr(index, context));
   const auto *m = std::max_element(permutation.begin(), permutation.end());
-  auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
+  auto permutationMap = getMultiDimMapWithTargets(*m + 1, permutation, context);
   assert(permutationMap.isPermutation() && "Invalid permutation vector");
   return permutationMap;
 }
 
+// Builds a map with `numDims` dimensions and no symbols whose i-th result is
+// the dimension expression for position `targets[i]`.
+AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
+                                               ArrayRef<unsigned> targets,
+                                               MLIRContext *context) {
+  SmallVector<AffineExpr, 4> dimExprs;
+  dimExprs.reserve(targets.size());
+  for (unsigned target : targets)
+    dimExprs.push_back(getAffineDimExpr(target, context));
+  return AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs,
+                        context);
+}
+
 template <typename AffineExprContainer>
 static SmallVector<AffineMap, 4>
 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
new file mode 100644
index 000000000000000..bda407b52bfd4f2
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -0,0 +1,188 @@
+// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// Meshes used by the tests in this file; only @mesh_2d specifies dim_sizes.
+mesh.cluster @mesh_1d(rank = 1)
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_3d(rank = 3)
+
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT: tosa.sigmoid
+  %sigmoid = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: return
+  return %sigmoid : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME:    %[[IN:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[IN_FOR_USERS:.*]] = mesh.shard %[[IN]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID:.*]] = tosa.sigmoid %[[IN_FOR_USERS]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID_SHARDED:.*]] = mesh.shard %[[SIGMOID]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[SIGMOID_SHARDED]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME:    %[[IN:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[IN_FOR_USERS:.*]] = mesh.shard %[[IN]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID:.*]] = tosa.sigmoid %[[IN_FOR_USERS]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID_SHARDED:.*]] = mesh.shard %[[SIGMOID]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[SIGMOID_SHARDED]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME:    %[[IN:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[IN_FOR_USERS:.*]] = mesh.shard %[[IN]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID:.*]] = tosa.sigmoid %[[IN_FOR_USERS]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID_SHARDED:.*]] = mesh.shard %[[SIGMOID]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[OUT_FOR_USERS:.*]] = mesh.shard %[[SIGMOID_SHARDED]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[OUT_FOR_USERS]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME:    %[[IN:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[IN_SHARDED:.*]] = mesh.shard %[[IN]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[IN_FOR_USERS:.*]] = mesh.shard %[[IN_SHARDED]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID:.*]] = tosa.sigmoid %[[IN_FOR_USERS]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[SIGMOID_SHARDED:.*]] = mesh.shard %[[SIGMOID]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[SIGMOID_SHARDED]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME:    %[[IN:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+  // CHECK-NEXT:  %[[IN_FOR_USERS:.*]] = mesh.shard %[[IN]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[TANH:.*]] = tosa.tanh %[[IN_FOR_USERS]]
+  // CHECK-NEXT:  %[[TANH_SHARDED:.*]] = mesh.shard %[[TANH]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[TANH_FOR_ABS:.*]] = mesh.shard %[[TANH_SHARDED]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[ABS:.*]] = tosa.abs %[[TANH_FOR_ABS]]
+  // CHECK-NEXT:  %[[ABS_SHARDED:.*]] = mesh.shard %[[ABS]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[TANH_FOR_NEG:.*]] = mesh.shard %[[TANH_SHARDED]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[NEG:.*]] = tosa.negate %[[TANH_FOR_NEG]]
+  // CHECK-NEXT:  %[[NEG_SHARDED:.*]] = mesh.shard %[[NEG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[ABS_SHARDED]], %[[NEG_SHARDED]]
+  return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME:     %[[LHS:.*]]: tensor<2x16x8xf32>, %[[RHS:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[LHS_FOR_USERS:.*]] = mesh.shard %[[LHS]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[RHS_FOR_USERS:.*]] = mesh.shard %[[RHS]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[MATMUL:.*]] = tosa.matmul %[[LHS_FOR_USERS]], %[[RHS_FOR_USERS]]
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[MATMUL_SHARDED:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[MATMUL_SHARDED]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
+// CHECK-SAME:     %[[LHS:.*]]: tensor<2x16x8xf32>, %[[RHS:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[LHS_FOR_USERS:.*]] = mesh.shard %[[LHS]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[RHS_FOR_USERS:.*]] = mesh.shard %[[RHS]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[MATMUL:.*]] = tosa.matmul %[[LHS_FOR_USERS]], %[[RHS_FOR_USERS]]
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[MATMUL_SHARDED:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[MATMUL_SHARDED]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME:     %[[LHS:.*]]: tensor<2x16x8xf32>, %[[RHS:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[LHS_FOR_USERS:.*]] = mesh.shard %[[LHS]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[RHS_FOR_USERS:.*]] = mesh.shard %[[RHS]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[MATMUL:.*]] = tosa.matmul %[[LHS_FOR_USERS]], %[[RHS_FOR_USERS]]
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[MATMUL_SHARDED:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[MATMUL_SHARDED]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
+// CHECK-SAME:     %[[LHS:.*]]: tensor<2x16x8xf32>, %[[RHS:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[LHS_FOR_USERS:.*]] = mesh.shard %[[LHS]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[RHS_FOR_USERS:.*]] = mesh.shard %[[RHS]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[MATMUL:.*]] = tosa.matmul %[[LHS_FOR_USERS]], %[[RHS_FOR_USERS]]
+  %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[MATMUL_SHARDED:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[MATMUL_SHARDED]]
+  return %2 : tensor<2x16x32xf32>
+}
+
+// MLP (matmul -> sigmoid -> matmul) sharded over @mesh_1d, following https://arxiv.org/abs/2211.05102 Figure 2(a)
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
+  // CHECK: %[[V0:.*]] = tosa.matmul
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
+  %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V6:.*]] = mesh.shard %[[ARG2]] to <@mesh_1d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x32x8xf32>
+  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
+  %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_1d, {{\[\[}}], [], []], partial = sum[0]> : tensor<2x4x8xf32>
+  %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: return %[[V9]]
+  return %5 : tensor<2x4x8xf32>
+}
+
+// MLP (matmul -> sigmoid -> matmul) sharded over @mesh_3d, following https://arxiv.org/abs/2211.05102 Figure 2(b)
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+  // CHECK-DAG: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> : tensor<2x4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_3d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[ARG1]] to <@mesh_3d, {{\[\[}}], [0], [1, 2]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
+  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_3d,  {{\[\[}}], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
+  %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
+  %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V7:.*]] = mesh.shard %[[V6]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[ARG2]] to <@mesh_3d, {{\[\[}}], [1, 2], [0]]> annotate_for_users : tensor<2x32x8xf32>
+  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
+  %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V11:.*]] = mesh.shard %[[V10]] to <@mesh_3d, {{\[\[}}], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
+  %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V12:.*]] = mesh.shard %[[V11]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+  %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: return %[[V12]]
+  return %6 : tensor<2x4x8xf32>
+}



More information about the Mlir-commits mailing list