[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)

Boian Petkantchin llvmlistbot at llvm.org
Fri Feb 2 16:24:36 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/80518

Add a pass that converts a function that has sharding annotations into SPMD form.

>From f210334f52556819ed2b7a7053fdb4b21358b0b7 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 26 Jan 2024 09:47:46 -0800
Subject: [PATCH] [mlir][mesh] Add spmdization pass

Add a pass that converts a function that has sharding annotations into SPMD
form.
---
 .../Func/Extensions/MeshShardingExtensions.h  |  23 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  43 +++
 .../Mesh/Interfaces/ShardingInterface.h       |  11 +-
 .../Mesh/Interfaces/ShardingInterface.td      |  47 ++-
 .../Mesh/Interfaces/ShardingInterfaceImpl.h   | 125 ++++++++
 .../mlir/Dialect/Mesh/Transforms/Passes.td    |  75 +++++
 .../Dialect/Mesh/Transforms/Spmdization.h     |  24 +-
 .../Dialect/Func/Extensions/AllExtensions.cpp |   2 +
 .../Dialect/Func/Extensions/CMakeLists.txt    |  15 +
 .../Extensions/MeshShardingExtensions.cpp     |  24 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  32 ++
 .../Mesh/Interfaces/ShardingInterface.cpp     |  65 +++-
 .../Mesh/Transforms/ShardingPropagation.cpp   |   6 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 291 +++++++++++++++---
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp |  30 +-
 mlir/test/Dialect/Mesh/spmdization.mlir       | 131 ++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |   8 +-
 17 files changed, 855 insertions(+), 97 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
 create mode 100644 mlir/test/Dialect/Mesh/spmdization.mlir

diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
new file mode 100644
index 0000000000000..9b7abbca5d762
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,25 @@
+//===- MeshShardingExtensions.h ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+#define MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace func {
+
+/// Registers external models of the mesh `ShardingInterface` for operations
+/// of the func dialect (at present `func.return`).
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 8e5e0f541ba5e..a819157fef147 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,11 +10,13 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 
 namespace mlir {
 namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+// True iff `attr` has no split axes and no partial axes (full replication).
+inline bool isFullReplication(MeshShardingAttr attr) {
+  return attr.getSplitAxes().empty() && attr.getPartialAxes().empty();
+}
+
 inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                             SymbolTableCollection &symbolTableCollection) {
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
   return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
 }
 
+// `ShardOp` carries its mesh symbol inside its sharding attribute.
+template <>
+inline mesh::MeshOp
+getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op, op.getShardAttr().getMesh(), symbolTableCollection);
+}
+
 // Get the number of processes that participate in each group
 // induced by `meshAxes`.
 template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,40 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+// Extent of a dimension of size `dim` after splitting it evenly into
+// `shardCount` shards. Returns dynamic if either input is dynamic.
+// `dim` must be divisible by `shardCount`.
+inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+// Inverse of `shardDimension`: extent of a dimension reassembled from
+// `shardCount` shards of size `dim`. Returns dynamic if either is dynamic.
+inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  return dim * shardCount;
+}
+
+// Return the sharded shape `shape` according to sharding `sharding`,
+// i.e. the shape of the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1 would
+// result in a shape for each shard of ?x2x?.
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
+                           MeshShardingAttr sharding);
+
+// If `type` is a ranked tensor type, return its sharded counterpart.
+//
+// Otherwise return `type` unchanged.
+// `sharding` in that case must be null.
+Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 3bef7e6babdec..cc90ddd40a622 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,11 +10,14 @@
 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
 class Operation;
+class IRMapping;
+class SymbolTableCollection;
 
 namespace mesh {
 
@@ -58,8 +61,20 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
 
 } // namespace detail
 
-} // namespace mesh
+// Spmdizes `op` by cloning it with `builder`. The clone's operands are
+// remapped through `spmdizationMap`, which also receives the mapping from
+// the old results to the clone's results.
+//
+// Assumes full replication on all ranked tensor arguments and results: their
+// shardings must be fully replicated, and values that are not ranked tensors
+// must have null shardings.
+void spmdizeFullyReplicatedOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
 
+} // namespace mesh
 } // namespace mlir
 
 /// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 21b6c8d4f599a..4afb1c36a72f7 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           return detail::defaultAddShardingAnnotations(
             $_op.getOperation(), b, shardingOption);
         }]
-      >
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Convert self to SPMD form.
+          This method is used during the spmdization pass of a program fully
+          annotated with shardings.
+
+          The spmdization algorithm reads the surrounding sharding
+          annotations from the IR for each argument/result and prepares
+          `operandShardings` and `resultShardings`.
+          Values that are not ranked tensors do not have sharding annotations.
+          In this case their corresponding MeshShardingAttr is null.
+
+          For convenience it also prepares `spmdizedOperands`, although
+          they can be retrieved from the `spmdizationMap`.
+
+          The `spmdizationMap` contains a mapping from unsharded to
+          sharded/spmdized values that are constructed during the spmdization
+          pass. The interface implementation must populate `spmdizationMap`
+          with the mapping for this op's results.
+
+          `builder` is set to insert new operations at the appropriate point.
+          The implementation should not return the builder to the original
+          insertion point.
+          It should leave it as is after all insertions are done.
+
+          The default implementation does full replication.
+          This assumes that all sharding annotations are for full replication.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"spmdize",
+        /*args=*/(ins
+          "ArrayRef<Value>": $spmdizedOperands,
+          "ArrayRef<MeshShardingAttr>": $operandShardings,
+          "ArrayRef<MeshShardingAttr>": $resultShardings,
+          "IRMapping &": $spmdizationMap,
+          "SymbolTableCollection &": $symbolTableCollection,
+          "OpBuilder &": $builder
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          spmdizeFullyReplicatedOperation(
+            *$_op.getOperation(), spmdizedOperands, operandShardings,
+            resultShardings, spmdizationMap, symbolTableCollection, builder);
+          return success();
+        }]>
     ];
 
     let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..8108386c2e043
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -0,0 +1,130 @@
+//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+
+class Operation;
+class IRMapping;
+class SymbolTableCollection;
+
+namespace mesh {
+
+// Inserts with `builder` a clone of `op` whose operands are remapped through
+// `spmdizationMap` (which also receives the old-to-new result mapping) and
+// whose ranked tensor results get the sharded types implied by
+// `resultShardings`. Values that are not ranked tensors have null shardings.
+void spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
+
+// All ranked tensor argument and result dimensions have
+// independent parallel loop iterators.
+template <typename Op>
+struct IndependentParallelIteratorDomainShardingInterface
+    : public ShardingInterface::ExternalModel<
+          IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
+  // One parallel iterator per dimension of every ranked tensor operand,
+  // followed by one per dimension of every ranked tensor result.
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
+    SmallVector<IteratorType> iterTypes;
+    for (Type t : operation->getOperandTypes())
+      populateIteratorTypes(t, iterTypes);
+    for (Type t : operation->getResultTypes())
+      populateIteratorTypes(t, iterTypes);
+    return iterTypes;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    // TODO: implement.
+    return SmallVector<AffineMap>();
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+
+private:
+  // Appends one parallel iterator per dimension of `t` if `t` is a ranked
+  // tensor type; leaves `iterTypes` untouched otherwise.
+  void populateIteratorTypes(Type t,
+                             SmallVector<IteratorType> &iterTypes) const {
+    auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(t);
+    if (!rankedTensorType)
+      return;
+    iterTypes.append(rankedTensorType.getRank(), IteratorType::Parallel);
+  }
+};
+
+// Sharding of elementwise operations like tensor addition and multiplication.
+// The iteration domain and all indexing maps are derived from the type of
+// operand 0; every dimension is parallel and indexed by the identity map.
+template <typename ElemwiseOp>
+struct ElementwiseShardingInterface
+    : public ShardingInterface::ExternalModel<
+          ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    RankedTensorType type = getFirstOperandRankedTensorType(op);
+    if (!type)
+      return {};
+    return SmallVector<IteratorType>(type.getRank(), IteratorType::Parallel);
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    RankedTensorType type = getFirstOperandRankedTensorType(op);
+    if (!type)
+      return {};
+    int64_t num = op->getNumOperands() + op->getNumResults();
+    AffineMap identity =
+        AffineMap::getMultiDimIdentityMap(type.getRank(), op->getContext());
+    return SmallVector<AffineMap>(num, identity);
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+
+private:
+  // Type of operand 0 if it is a ranked tensor; null if the op has no
+  // operands or operand 0 has any other type.
+  static RankedTensorType getFirstOperandRankedTensorType(Operation *op) {
+    if (op->getNumOperands() == 0)
+      return {};
+    return llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+  }
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index c09cf3e710d42..3758b617bc14e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -29,4 +29,76 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+  let summary = "Partition a function into SPMD form.";
+  let description = [{
+    This pass fits in right after a pass that annotates the function with
+    shardings like the `ShardingPropagation` pass.
+    It operates on a fully annotated IR.
+
+    A fully annotated IR requires that all ranked tensor operands, results and
+    block arguments are annotated with the `mesh.shard` operation.
+
+    All direct descendant operations in the function must implement the
+    `ShardingInterface` interface or all their ranked tensor operands and
+    results must have full replication sharding.
+
+    The input IR must have sharding annotations such that each operation
+    that implements `ShardingInterface` can be handled during spmdization by
+    its `spmdize` method.
+    This can be achieved with the `ShardingPropagation` pass.
+
+    If the function has multiple terminating blocks,
+    it is the responsibility of the one who annotates the function with
+    shardings to make sure that all returns are consistent, that is,
+    have the same sharding.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+
+    func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
+      %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+      return %4 : tensor<2xi8>
+    }
+    ```
+    Spmdizing the above would result in
+    * resharding the fully replicated input by splitting it along the only
+    tensor axis.
+    * performing the element-wise `abs` operation on each device.
+    * resharding back to full replication with an all-gather.
+
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+    func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
+      // Reshard [[]] -> [[0]]
+      %c0 = arith.constant 0 : index
+      %c2 = arith.constant 2 : index
+      %0 = mesh.process_multi_index on @mesh_1d axes = [0] : index
+      %1 = mesh.mesh_shape @mesh_1d axes = [0] : index
+      %2 = arith.remui %c2, %1 : index
+      %3 = arith.cmpi eq, %2, %c0 : index
+      cf.assert %3, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
+      %4 = arith.divui %c2, %1 : index
+      %5 = arith.muli %4, %0 : index
+      %extracted_slice = tensor.extract_slice %arg0[%5] [1] [1] : tensor<2xi8> to tensor<1xi8>
+
+      %6 = tosa.abs %extracted_slice : (tensor<1xi8>) -> tensor<1xi8>
+
+      // Reshard [[0]] -> [[]]
+      %7 = mesh.all_gather %6 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+
+      return %7 : tensor<2xi8>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index 7cb992aac019b..f847ce30a1b40 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -16,16 +16,35 @@
 namespace mlir {
 namespace mesh {
 
-// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshShardingAttr sharding);
-
 // Insert resharding spmdization of the value `sourceShardValue`
 // from sharding `source` to sharding `target`.
 // `sourceShardValue` is the already sharded value according to `source`.
+//
+// Example
+//
+// ```mlir
+//   mesh.mesh @mesh_1d(shape = 2)
+//   ...
+//   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+//   %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+// ```
+//
+// Will result in
+//
+// ```mlir
+//   %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+//     tensor<1xi8> -> tensor<2xi8>
+// ```
 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue);
+
+// Same as the overload above, but without an explicit `MeshOp`; the mesh is
+// presumably resolved via `symbolTableCollection` (TODO confirm in the impl).
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+                               ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue,
+                               SymbolTableCollection &symbolTableCollection);
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry);
 
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index b3a3a37663a21..eb6b59bb00f1b 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,9 +8,11 @@
 
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
 
 using namespace mlir;
 
 void mlir::func::registerAllExtensions(DialectRegistry &registry) {
   registerInlinerExtension(registry);
+  func::registerShardingInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 8712b1583719b..47363f48f95cc 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(LLVM_OPTIONAL_SOURCES
   AllExtensions.cpp
   InlinerExtension.cpp
+  MeshShardingExtensions.cpp
   )
 
 add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -16,6 +17,18 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
   MLIRFuncDialect
   )
 
+add_mlir_extension_library(MLIRFuncMeshShardingExtensions
+  MeshShardingExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRShardingInterface
+  )
+
 add_mlir_extension_library(MLIRFuncAllExtensions
   AllExtensions.cpp
 
@@ -24,4 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
 
   LINK_LIBS PUBLIC
   MLIRFuncInlinerExtension
+  MLIRFuncMeshShardingExtensions
   )
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
new file mode 100644
index 0000000000000..da508cc95bfe1
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
@@ -0,0 +1,26 @@
+//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir::func {
+
+// Registers a dialect extension attaching to `func.return` the mesh sharding
+// model in which every ranked tensor operand has independent parallel loops.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, FuncDialect * /*dialect*/) {
+    using ReturnOpShardingModel =
+        mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>;
+    ReturnOp::attachInterface<ReturnOpShardingModel>(*ctx);
+  });
+}
+
+} // namespace mlir::func
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 7c6a5e02ed583..a65b8f2e5a237 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -175,6 +175,40 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
   }
 }
 
+// Copies `inShape` into `outShape` (which must already have the same size)
+// and replaces the extent of every tensor axis listed in `splitAxes` by its
+// per-shard extent over the mesh axes that split it.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+          typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+                       const SplitAxes &splitAxes, OutShape &outShape) {
+  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+            llvm::adl_begin(outShape));
+  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+    outShape[tensorAxis] = shardDimension(
+        inShape[tensorAxis],
+        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+  }
+}
+
+ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
+                                 MeshShardingAttr sharding) {
+  SmallVector<int64_t> resShapeArr(shape.getShape().size());
+  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+             resShapeArr);
+  return shape.clone(resShapeArr);
+}
+
+Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
+  if (auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(type)) {
+    assert(sharding && "a ranked tensor type requires a sharding");
+    return shardShapedType(rankedTensorType, mesh, sharding);
+  }
+
+  assert(!sharding && "only ranked tensor types can be sharded");
+  return type;
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.mesh op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 5dc91ff1c02d2..a0a573f65a346 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -7,14 +7,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/Debug.h"
 
-#include <algorithm>
 #include <utility>
 
 #define DEBUG_TYPE "sharding-interface"
@@ -391,8 +393,6 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 // detail::defaultAddShardingAnnotations
 //===----------------------------------------------------------------------===//
 
-namespace {
-
 // To add a `mesh.shard` op for the given result, based on the details provided
 // in `shardingOption`, `map`, and `loopTypes`.
 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
@@ -493,8 +493,6 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
   return success();
 }
 
-} // namespace
-
 LogicalResult mesh::detail::defaultAddShardingAnnotations(
     Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
@@ -519,3 +517,69 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
 
   return success();
 }
+
+// A ranked tensor value is compatible only with a fully replicated sharding;
+// any other value is compatible only with a null sharding.
+// Only used in assertions, hence `[[maybe_unused]]` for release builds.
+[[maybe_unused]] static bool
+isValueCompatibleWithFullReplicationSharding(Value value,
+                                             MeshShardingAttr sharding) {
+  if (llvm::isa<RankedTensorType>(value.getType())) {
+    return sharding && isFullReplication(sharding);
+  }
+
+  return !sharding;
+}
+
+// True if `values` and `shardings` have the same length and each value is
+// compatible with the sharding at the same position.
+template <typename ValuesT, typename ShardingsT>
+[[maybe_unused]] static bool areValuesCompatibleWithFullReplicationShardings(
+    ValuesT &&values, ShardingsT &&shardings) {
+  if (values.size() != shardings.size()) {
+    return false;
+  }
+  return llvm::all_of(llvm::zip(std::forward<ValuesT>(values),
+                                std::forward<ShardingsT>(shardings)),
+                      [](auto valueAndSharding) {
+                        return isValueCompatibleWithFullReplicationSharding(
+                            std::get<0>(valueAndSharding),
+                            std::get<1>(valueAndSharding));
+                      });
+}
+
+void mesh::spmdizeFullyReplicatedOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder) {
+  assert(spmdizedOperands.size() == operandShardings.size());
+  assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
+                                                         operandShardings));
+  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
+                                                         resultShardings));
+  // `clone` will populate the mapping of old to new results.
+  builder.clone(op, spmdizationMap);
+}
+
+void mesh::spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder) {
+  // `clone` will populate the mapping of old to new results.
+  Operation *newOp = builder.clone(op, spmdizationMap);
+  // Set the result types to the sharded counterparts. Results that are not
+  // ranked tensors have a null sharding and keep their original type.
+  for (auto [newResult, sharding] :
+       llvm::zip_equal(newOp->getResults(), resultShardings)) {
+    if (!sharding) {
+      assert(!llvm::isa<RankedTensorType>(newResult.getType()) &&
+             "ranked tensor results must have a sharding");
+      continue;
+    }
+    newResult.setType(shardType(newResult.getType(),
+                                getMesh(&op, sharding.getMesh(), symbolTable),
+                                sharding));
+  }
+}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index f3cd12f38879d..9f2647b21cbfc 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -29,8 +29,6 @@ namespace mesh {
 using namespace mlir;
 using namespace mlir::mesh;
 
-namespace {
-
 //===----------------------------------------------------------------------===//
 // Utilities
 //===----------------------------------------------------------------------===//
@@ -83,7 +81,7 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
 // `getShardingOption` method. If the inferred sharding option is not empty, add
 // a `mesh.shard` operation for all remaining operands and results that do not
 // have sharding annotations.
-LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
     return success();
 
@@ -207,5 +205,3 @@ struct ShardingPropagation
         return signalPassFailure();
   }
 };
-
-} // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 5554edac4d2f6..509861af937ca 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -7,79 +7,40 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/MathExtras.h"
-#include "llvm/ADT/ADL.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include <algorithm>
+#include "llvm/Support/Casting.h"
 #include <iterator>
-#include <numeric>
 #include <optional>
 #include <tuple>
 #include <type_traits>
 
-namespace mlir {
-namespace mesh {
-
-int64_t shardDimension(int64_t dim, int64_t shardCount) {
-  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
-    return ShapedType::kDynamic;
-
-  assert(dim % shardCount == 0);
-  return ceilDiv(dim, shardCount);
-}
-
-int64_t unshardDimension(int64_t dim, int64_t shardCount) {
-  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
-    return ShapedType::kDynamic;
-
-  return dim * shardCount;
-}
-
-// Compute the shape for the tensor on each device in the mesh.
-// Example:
-// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
-// would result in a shape for each shard of ?x2x?.
-template <typename InShape, typename MeshShape, typename SplitAxes,
-          typename OutShape>
-static void shardShape(const InShape &inShape, const MeshShape &meshShape,
-                       const SplitAxes &splitAxes, OutShape &outShape) {
-  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
-            llvm::adl_begin(outShape));
-  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] = shardDimension(
-        inShape[tensorAxis],
-        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
-  }
-}
-
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshShardingAttr sharding) {
-  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
-  SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr);
-  return shape.clone(resShapeArr);
-}
+namespace mlir::mesh {
 
 template <typename SourceAxes, typename TargetAxes>
 static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
@@ -616,10 +577,240 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
       source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
 }
 
+// Reshard `sourceShardValue` (already spmdized according to `source`) so that
+// it matches the sharding of `target`. The mesh is resolved from the `source`
+// annotation through `symbolTableCollection`; both annotations are required to
+// reference the same mesh (checked by assertion).
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+                               ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue,
+                               SymbolTableCollection &symbolTableCollection) {
+  MeshOp mesh = getMesh(source, symbolTableCollection);
+  assert(mesh && mesh == getMesh(target, symbolTableCollection));
+  TypedValue<ShapedType> targetShardValue =
+      reshard(builder, mesh, source, target, sourceShardValue);
+  return targetShardValue;
+}
+
 void reshardingRegisterDependentDialects(DialectRegistry &registry) {
+  // Registers the arith, mesh, tensor and cf dialects; presumably these cover
+  // the ops that resharding may emit (NOTE(review): confirm against reshard).
   registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
                   cf::ControlFlowDialect>();
 }
 
-} // namespace mesh
-} // namespace mlir
+// Pull in the generated pass base class (`impl::SpmdizationBase`).
+#define GEN_PASS_DEF_SPMDIZATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+// Compute the argument types for the spmdized counterpart of `block`.
+// Each ranked-tensor argument is expected to have exactly one use, a
+// `mesh.shard` op, whose sharding and mesh determine the sharded shape.
+// Arguments of any other type keep their original type.
+SmallVector<Type>
+shardedBlockArgumentTypes(Block &block,
+                          SymbolTableCollection &symbolTableCollection) {
+  SmallVector<Type> shardedTypes;
+  shardedTypes.reserve(block.getNumArguments());
+  for (BlockArgument blockArg : block.getArguments()) {
+    auto tensorArg = blockArg.dyn_cast<TypedValue<RankedTensorType>>();
+    if (!tensorArg) {
+      shardedTypes.push_back(blockArg.getType());
+      continue;
+    }
+
+    assert(tensorArg.hasOneUse());
+    auto annotation = llvm::dyn_cast<ShardOp>(*tensorArg.getUsers().begin());
+    assert(annotation);
+    MeshOp mesh = getMesh(annotation, symbolTableCollection);
+    Type shardedType = shardShapedType(tensorArg.getType(), mesh,
+                                       annotation.getShardAttr());
+    shardedTypes.push_back(shardedType);
+  }
+  return shardedTypes;
+}
+
+// Spmdize `op` given its already-spmdized operands and the sharding
+// annotations of its operands and results. On success every result of `op`
+// has an entry in `spmdizationMap` (checked by assertion).
+static LogicalResult spmdizeOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+  if (auto shardingInterface = llvm::dyn_cast<ShardingInterface>(op)) {
+    LogicalResult status = shardingInterface.spmdize(
+        spmdizedOperands, operandShardings, resultShardings, spmdizationMap,
+        symbolTableCollection, builder);
+    if (failed(status))
+      return failure();
+  } else {
+    // Without a sharding interface we are conservative and treat the op as
+    // fully replicated on all devices.
+    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
+                                    resultShardings, spmdizationMap,
+                                    symbolTableCollection, builder);
+  }
+
+  // Every original result must now map to its spmdized counterpart.
+  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
+    return spmdizationMap.contains(result);
+  }));
+  return success();
+}
+
+// Collect the sharding annotation of every operand of `op`.
+// A ranked-tensor operand must be produced by a `mesh.shard` op with
+// `annotate_for_users` set. Operands of other types are not required to be
+// annotated and get a null attribute.
+static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
+  SmallVector<MeshShardingAttr> shardings;
+  shardings.reserve(op.getNumOperands());
+  for (Value operand : op.getOperands()) {
+    if (!llvm::isa<RankedTensorType>(operand.getType())) {
+      shardings.push_back(MeshShardingAttr());
+      continue;
+    }
+
+    Operation *producer = operand.getDefiningOp();
+    assert(producer);
+    auto annotation = llvm::cast<ShardOp>(producer);
+    assert(annotation.getAnnotateForUsers());
+    shardings.push_back(annotation.getShard());
+  }
+  return shardings;
+}
+
+// Collect the sharding annotation of every result of `op`.
+// A ranked-tensor result must have exactly one user, a `mesh.shard` op without
+// `annotate_for_users`. Results of other types are not required to be
+// annotated and get a null attribute.
+static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
+  SmallVector<MeshShardingAttr> shardings;
+  shardings.reserve(op.getNumResults());
+  for (OpResult result : op.getResults()) {
+    if (!llvm::isa<RankedTensorType>(result.getType())) {
+      shardings.push_back(MeshShardingAttr());
+      continue;
+    }
+
+    assert(result.hasOneUse());
+    auto annotation = llvm::cast<ShardOp>(*result.getUsers().begin());
+    assert(!annotation.getAnnotateForUsers());
+    shardings.push_back(annotation.getShard());
+  }
+  return shardings;
+}
+
+// Spmdize `op`, looking up its already-spmdized operands in `spmdizationMap`.
+// `mesh.shard` ops are special-cased:
+// - without `annotate_for_users`, nothing is emitted;
+// - with `annotate_for_users`, a resharding is emitted. It goes from the
+//   sharding of the source `mesh.shard` op to this op's sharding.
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+                 SymbolTableCollection &symbolTableCollection,
+                 OpBuilder &builder) {
+  if (auto targetShardOp = llvm::dyn_cast<ShardOp>(op)) {
+    if (!targetShardOp.getAnnotateForUsers())
+      return success();
+
+    // The operand of a user annotation is the matching result annotation.
+    auto sourceShardOp =
+        llvm::cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
+    assert(!sourceShardOp.getAnnotateForUsers());
+    auto sourceSpmdValue = spmdizationMap.lookup(sourceShardOp.getOperand())
+                               .cast<TypedValue<ShapedType>>();
+    Value reshardedValue = reshard(builder, sourceShardOp, targetShardOp,
+                                   sourceSpmdValue, symbolTableCollection);
+    assert(!spmdizationMap.contains(targetShardOp.getResult()));
+    spmdizationMap.map(targetShardOp.getResult(), reshardedValue);
+    return success();
+  }
+
+  SmallVector<Value> spmdizedOperands;
+  spmdizedOperands.reserve(op.getNumOperands());
+  for (Value operand : op.getOperands()) {
+    assert(spmdizationMap.contains(operand));
+    spmdizedOperands.push_back(spmdizationMap.lookup(operand));
+  }
+  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
+                          getResultShardings(op), spmdizationMap,
+                          symbolTableCollection, builder);
+}
+
+// Spmdize `block` into a new block appended at the end of its parent region.
+// The new block's arguments get the sharded types of the original arguments
+// (see `shardedBlockArgumentTypes`). Every operation of `block` is then
+// spmdized into it, in order. This function does not erase `block`; the
+// caller does that.
+static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
+                                  SymbolTableCollection &symbolTableCollection,
+                                  OpBuilder &builder) {
+  // `createBlock` moves the builder's insertion point into the new block.
+  // The guard must be installed before that call so that the caller's
+  // insertion point is restored on every return path.
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  SmallVector<Location> argLocations;
+  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
+                  [](BlockArgument arg) { return arg.getLoc(); });
+  Block *newBlock = builder.createBlock(
+      block.getParent(), {},
+      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
+  for (auto [unshardedBlockArg, spmdizedBlockArg] :
+       llvm::zip(block.getArguments(), newBlock->getArguments())) {
+    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+  }
+
+  builder.setInsertionPointToEnd(newBlock);
+  for (Operation &op : block.getOperations()) {
+    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
+                                builder))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+// Spmdize the body of function `op` in place.
+// - Each original block is rewritten into a new block appended to the body.
+// - The original blocks are then erased.
+// - The function type is rebuilt from two sources: the new entry block's
+//   argument types, and the operand types of the first block that ends in a
+//   `func.return`.
+static LogicalResult
+spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+              SymbolTableCollection &symbolTableCollection) {
+  OpBuilder builder(op.getFunctionBody());
+
+  // Snapshot the original blocks to not mess up the iteration when adding new
+  // blocks.
+  SmallVector<Block *> originalBlocks;
+  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
+                  [](Block &b) { return &b; });
+
+  for (Block *block : originalBlocks) {
+    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
+                            builder))) {
+      return failure();
+    }
+  }
+
+  // A value defined in one original block may be used in another (e.g. entry
+  // block arguments used further down). Drop every reference first, so that
+  // erasing the blocks one at a time never destroys a value that still has
+  // uses.
+  for (Block *block : originalBlocks) {
+    block->dropAllReferences();
+  }
+  for (Block *block : originalBlocks) {
+    block->erase();
+  }
+
+  // Find a return op and change the function results signature to its operands
+  // signature. Blocks may end in other terminators (e.g. branches), so probe
+  // with `dyn_cast` rather than `cast`, which would assert on them.
+  func::ReturnOp returnOp;
+  for (Block &block : op.getFunctionBody()) {
+    if (block.empty()) {
+      continue;
+    }
+
+    returnOp = llvm::dyn_cast<func::ReturnOp>(block.back());
+    if (returnOp) {
+      break;
+    }
+  }
+  assert(returnOp);
+  op.setFunctionType(FunctionType::get(
+      op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
+      returnOp->getOperandTypes()));
+
+  return success();
+}
+
+// Pass that spmdizes a `func.func` whose values carry `mesh.shard`
+// annotations, rewriting it in place via `spmdizeFuncOp`.
+struct Spmdization : public impl::SpmdizationBase<Spmdization> {
+  void runOnOperation() override {
+    SymbolTableCollection symbolTableCollection;
+    IRMapping spmdizationMap;
+    LogicalResult status =
+        spmdizeFuncOp(getOperation(), spmdizationMap, symbolTableCollection);
+    if (failed(status))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Dialects registered for resharding, plus the mesh dialect itself.
+    reshardingRegisterDependentDialects(registry);
+    registry.insert<mesh::MeshDialect>();
+  }
+};
+
+} // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index dace86533c0e2..fefbca8c2e091 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/DialectRegistry.h"
@@ -23,33 +24,6 @@ using namespace mlir::mesh;
 
 namespace {
 
-template <typename ElemwiseOp>
-struct ElemwiseSharding
-    : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
-                                              ElemwiseOp> {
-  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
-    Value val = op->getOperand(0);
-    auto type = val.getType().dyn_cast<RankedTensorType>();
-    if (!type)
-      return {};
-    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
-    return types;
-  }
-
-  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
-    MLIRContext *ctx = op->getContext();
-    Value val = op->getOperand(0);
-    auto type = val.getType().dyn_cast<RankedTensorType>();
-    if (!type)
-      return {};
-    int64_t rank = type.getRank();
-    int64_t num = op->getNumOperands() + op->getNumResults();
-    SmallVector<AffineMap> maps(num,
-                                AffineMap::getMultiDimIdentityMap(rank, ctx));
-    return maps;
-  }
-};
-
 // loop types: [parallel, parallel, parallel, reduction_sum]
 // indexing maps:
 // (d0, d1, d2, d3) -> (d0, d1, d3)
@@ -83,7 +57,7 @@ struct MatMulOpSharding
 
 template <typename OpType>
 static void registerElemwiseOne(MLIRContext *ctx) {
-  OpType::template attachInterface<ElemwiseSharding<OpType>>(*ctx);
+  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
 }
 
 /// Variadic helper function.
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
new file mode 100644
index 0000000000000..9993c1518e9ea
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+
+// A 1-D mesh with 2 devices, used by every test function in this file.
+mesh.mesh @mesh_1d(shape = 2)
+
+// Re-annotating a fully replicated value as fully replicated needs no
+// communication: the argument is returned unchanged.
+// CHECK-LABEL: func @full_replication
+func.func @full_replication(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+  // CHECK: return %[[ARG]] : tensor<2xi8>
+  return %1 : tensor<2xi8>
+}
+
+// Moving the split from tensor axis 0 to tensor axis 1 (same mesh axis)
+// lowers to a single mesh.all_to_all.
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
+  %arg0: tensor<2x2xi8>
+// CHECK-SAME: -> tensor<2x1xi8> {
+) -> tensor<2x2xi8> {
+  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
+  // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2x2xi8>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<2x2xi8>
+  // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
+  return %1 : tensor<2x2xi8>
+}
+
+// Values that are not ranked tensors need no sharding annotations and are
+// left unchanged.
+// CHECK-LABEL: func @non_tensor_value
+func.func @non_tensor_value(
+  // CHECK-SAME: %[[ARG:.*]]: i8
+  %arg0: i8
+// CHECK-SAME: -> i8 {
+) -> i8 {
+  // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
+  %0 = arith.addi %arg0, %arg0 : i8
+  // CHECK: return %[[RES]] : i8
+  return %0 : i8
+}
+
+// Operand and result use the same split, so tosa.abs is applied directly to
+// the local shard: tensor<2xi8> split over 2 devices becomes tensor<1xi8>.
+// CHECK-LABEL: func @unary_elementwise
+func.func @unary_elementwise(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
+  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+  %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %4 = mesh.shard %3 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %4 : tensor<2xi8>
+}
+
+// Resharding around an elementwise op: full replication -> split along tensor
+// axis 0 (tensor.extract_slice) -> abs -> full replication (mesh.all_gather).
+// CHECK-LABEL: func @unary_elementwise_with_resharding
+func.func @unary_elementwise_with_resharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+  // We don't care about the whole resharding IR, just that it happens.
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
+  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
+  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+  // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
+  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+  %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<2xi8>
+  return %4 : tensor<2xi8>
+}
+
+// Both operands and the result share the same split, so tosa.add is applied
+// directly to the local shards without any resharding.
+// CHECK-LABEL: func @binary_elementwise
+func.func @binary_elementwise(
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
+  %arg0: tensor<2xi8>,
+  // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
+  %arg1: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %op_arg0 = mesh.shard %arg0_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %arg1_sharded = mesh.shard %arg1 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %op_arg1 = mesh.shard %arg1_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
+  %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
+  %op_res_sharded = mesh.shard %op_res to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %res = mesh.shard %op_res_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %res : tensor<2xi8>
+}
+
+// Chained elementwise ops with resharding in between:
+// reshard -> abs -> reshard -> abs -> reshard.
+// CHECK-LABEL: func @multiple_chained_ops
+func.func @multiple_chained_ops(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  // We don't care about the whole resharding IR, just that it happens.
+  // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
+  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
+  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+  // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
+  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+  %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+  // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
+  %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
+  // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1]
+  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
+  %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: return %[[RESHARD3]] : tensor<1xi8>
+  return %7 : tensor<2xi8>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ba2515050c978..3cc710ddcafc4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3252,7 +3252,10 @@ gentbl_cc_library(
 cc_library(
     name = "MeshShardingInterface",
     srcs = ["lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp"],
-    hdrs = ["include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"],
+    hdrs = [
+        "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h",
+        "include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h",
+    ],
     includes = ["include"],
     deps = [
         ":DialectUtils",
@@ -4548,8 +4551,9 @@ cc_library(
     deps = [
         ":ControlFlowDialect",
         ":FuncDialect",
-        ":IR",
         ":InferTypeOpInterface",
+        ":IR",
+        ":MeshShardingInterface",
     ],
 )
 



More information about the Mlir-commits mailing list