[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 2 16:25:07 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

Add a pass that converts a function that has sharding annotations into SPMD form.

---

Patch is 49.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80518.diff


17 Files Affected:

- (added) mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h (+23) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+43) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+10-1) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+46-1) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+125) 
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+75) 
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+20-4) 
- (modified) mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp (+2) 
- (modified) mlir/lib/Dialect/Func/Extensions/CMakeLists.txt (+15) 
- (added) mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp (+24) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+32) 
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+59-6) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+1-5) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+241-50) 
- (modified) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+2-28) 
- (added) mlir/test/Dialect/Mesh/spmdization.mlir (+131) 
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+6-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
new file mode 100644
index 0000000000000..9b7abbca5d762
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
+//===- MeshShardingExtensions.h ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+#define MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace func {
+// Registers the mesh `ShardingInterface` external models for func dialect ops.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 8e5e0f541ba5e..a819157fef147 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,11 +10,13 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 
 namespace mlir {
 namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+// True if the tensor is replicated on all processes: no split or partial axes.
+inline bool isFullReplication(MeshShardingAttr attr) {
+  return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+}
+
 inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                             SymbolTableCollection &symbolTableCollection) {
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
   return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
 }
 
+template <>
+inline mesh::MeshOp
+getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
+                 symbolTableCollection);
+}
+
 // Get the number of processes that participate in each group
 // induced by `meshAxes`.
 template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,35 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  return dim * shardCount;
+}
+
+// Return the sharded shape `shape` according to sharding `sharding`.
+// The shape for the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1 would
+// result in a shape for each shard of ?x2x?.
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
+                           MeshShardingAttr sharding);
+
+// If `type` is a ranked tensor type, return its sharded counterpart.
+//
+// Otherwise return `type` unchanged;
+// `sharding` must be null in that case.
+Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 3bef7e6babdec..cc90ddd40a622 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,11 +10,14 @@
 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
 class Operation;
+class IRMapping;
+class SymbolTableCollection;
 
 namespace mesh {
 
@@ -58,8 +61,14 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
 
 } // namespace detail
 
-} // namespace mesh
+// Spmdize an op whose ranked tensor operands and results are fully replicated.
+void spmdizeFullyReplicatedOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
 
+} // namespace mesh
 } // namespace mlir
 
 /// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 21b6c8d4f599a..4afb1c36a72f7 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           return detail::defaultAddShardingAnnotations(
             $_op.getOperation(), b, shardingOption);
         }]
-      >
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Convert self to SPMD form.
+          This method is used during the spmdization pass of a program fully
+          annotated with shardings.
+
+          The spmdization algorithm would read the surrounding sharding
+          annotations from the IR for each argument/result and prepare
+          `operandShardings` and `resultShardings`.
+          Values that are not ranked tensors do not have sharding annotations.
+          In this case their corresponding MeshShardingAttr is null.
+
+          For convenience it will also prepare `spmdizedOperands`, although
+          they can be retrieved from the `spmdizationMap`.
+
+          The `spmdizationMap` contains a mapping from unsharded to
+          sharded/spmdized values that are constructed during the spmdization
+          pass. The interface implementation must populate `spmdizationMap`
+          with the mapping for this op's results.
+
+          `builder` is set to insert new operations at the appropriate point.
+          The implementation should not return the builder to the original
+          insertion point.
+          It should leave it as is after all insertions are done.
+
+          The default implementation does full replication.
+          This assumes that all sharding annotations are for full replication.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"spmdize",
+        /*args=*/(ins
+          "ArrayRef<Value>": $spmdizedOperands,
+          "ArrayRef<MeshShardingAttr>": $operandShardings,
+          "ArrayRef<MeshShardingAttr>": $resultShardings,
+          "IRMapping &": $spmdizationMap,
+          "SymbolTableCollection &": $symbolTableCollection,
+          "OpBuilder &": $builder
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          spmdizeFullyReplicatedOperation(
+            *$_op.getOperation(), spmdizedOperands, operandShardings,
+            resultShardings, spmdizationMap, symbolTableCollection, builder);
+          return success();
+        }]>
     ];
 
     let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..8108386c2e043
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -0,0 +1,125 @@
+//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+
+class Operation;
+class IRMapping;
+class SymbolTableCollection;
+
+namespace mesh {
+
+// Inserts a clone of the operation that has all ranked tensor
+// arguments/results sharded.
+void spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
+
+// All ranked tensor argument and result dimensions have
+// independent parallel loop iterators.
+template <typename Op>
+struct IndependentParallelIteratorDomainShardingInterface
+    : public ShardingInterface::ExternalModel<
+          IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
+  // One parallel loop per dimension of each ranked tensor operand, followed
+  // by one per dimension of each ranked tensor result.
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
+    SmallVector<IteratorType> iterTypes;
+    for (Type t : operation->getOperandTypes())
+      populateIteratorTypes(t, iterTypes);
+    for (Type t : operation->getResultTypes())
+      populateIteratorTypes(t, iterTypes);
+    return iterTypes;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    // TODO: implement.
+    return SmallVector<AffineMap>();
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+
+private:
+  // Appends one `IteratorType::Parallel` per dimension of `t` when `t` is a
+  // ranked tensor, so each tensor dimension becomes its own independent loop.
+  // Types that are not ranked tensors contribute no loop dimensions.
+  void populateIteratorTypes(Type t,
+                             SmallVector<IteratorType> &iterTypes) const {
+    auto rankedTensorType = dyn_cast<RankedTensorType>(t);
+    if (!rankedTensorType)
+      return;
+
+    iterTypes.append(rankedTensorType.getRank(), IteratorType::Parallel);
+  }
+};
+
+// Sharding of elementwise operations like tensor addition and multiplication.
+template <typename ElemwiseOp>
+struct ElementwiseShardingInterface
+    : public ShardingInterface::ExternalModel<
+          ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    Value val = op->getOperand(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+    return types;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    Value val = op->getOperand(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    int64_t rank = type.getRank();
+    int64_t num = op->getNumOperands() + op->getNumResults();
+    SmallVector<AffineMap> maps(num,
+                                AffineMap::getMultiDimIdentityMap(rank, ctx));
+    return maps;
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index c09cf3e710d42..3758b617bc14e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -29,4 +29,79 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+  let summary = "Partition a function into SPMD form.";
+  let description = [{
+    This pass fits in right after a pass that annotates the function with
+    shardings like the `ShardingPropagation` pass.
+    It operates on a fully annotated IR.
+
+    A fully annotated IR requires that all ranked tensor operands, results and
+    block arguments are annotated with the `mesh.shard` operation.
+
+    All direct descendant operations in the function must implement the
+    `ShardingInterface` interface or all their ranked tensor operands and
+    results must have full replication sharding.
+
+    The input IR must have sharding annotations such that each operation
+    that implements `ShardingInterface` can be handled during spmdization by
+    its `spmdize` method.
+    This can be achieved with the `ShardingPropagation` pass.
+
+    If the function has multiple terminating blocks,
+    it is the responsibility of whoever annotates the function with
+    shardings to make sure that all returns are consistent, that is,
+    that they have the same sharding.
+
+    Example:
+
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+
+    func.func @f(
+      %arg0: tensor<2xi8>
+    ) -> tensor<2xi8> {
+      %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+      return %4 : tensor<2xi8>
+    }
+    ```
+    Spmdizing the above would result in:
+    * Resharding the fully replicated input by splitting it along the only
+      tensor axis.
+    * Performing the element-wise `abs` operation on each device.
+    * Resharding back to full replication with an all-gather.
+
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+    func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
+      // Reshard [[]] -> [[0]]
+      %c0 = arith.constant 0 : index
+      %c2 = arith.constant 2 : index
+      %0 = mesh.process_multi_index on @mesh_1d axes = [0] : index
+      %1 = mesh.mesh_shape @mesh_1d axes = [0] : index
+      %2 = arith.remui %c2, %1 : index
+      %3 = arith.cmpi eq, %2, %c0 : index
+      cf.assert %3, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
+      %4 = arith.divui %c2, %1 : index
+      %5 = arith.muli %4, %0 : index
+      %extracted_slice = tensor.extract_slice %arg0[%5] [1] [1] : tensor<2xi8> to tensor<1xi8>
+
+      %6 = tosa.abs %extracted_slice : (tensor<1xi8>) -> tensor<1xi8>
+
+      // Reshard [[0]] -> [[]]
+      %7 = mesh.all_gather %6 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+
+      return %7 : tensor<2xi8>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index 7cb992aac019b..f847ce30a1b40 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -16,16 +16,32 @@
 namespace mlir {
 namespace mesh {
 
-// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshShardingAttr sharding);
-
 // Insert resharding spmdization of the value `sourceShardValue`
 // from sharding `source` to sharding `target`.
 // `sourceShardValue` is the already sharded value according to `source`.
+//
+// Example:
+//
+// ```mlir
+//   mesh.mesh @mesh_1d(shape = 2)
+//   ...
+//   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+//   %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+// ```
+//
+// Will result in, with `%0` being the already sharded source value:
+//
+// ```mlir
+//   %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+//     tensor<1xi8> -> tensor<2xi8>
+// ```
 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue);
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+                               ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue,
+                               SymbolTableCollection &symbolTableCollection);
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry);
 
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index b3a3a37663a21..eb6b59bb00f1b 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,9 +8,11 @@
 
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
 
 using namespace mlir;
 
 void mlir::func::registerAllExtensions(DialectRegistry &registry) {
   registerInlinerExtension(registry);
+  registerShardingInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 8712b1583719b..47363f48f95cc 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(LLVM_OPTIONAL_SOURCES
   AllExtensions.cpp
   InlinerExtension.cpp
+  MeshShardingExtensions.cpp
   )
 
 add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -16,6 +17,19 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
   MLIRFuncDialect
   )
 
+# Mesh ShardingInterface external models for func dialect operations.
+add_mlir_extension_library(MLIRFuncMeshShardingExtensions
+  MeshShardingExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRShardingInterface
+  )
+
 add_mlir_extension_library(MLIRFuncAllExtensions
   AllExtensions.cpp
 
@@ -24,4 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
 
   LINK_LIBS PUBLIC
   MLIRFuncInlinerExtension
+  MLIRFuncMeshShardingExtensions
   )
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
new file mode 100644
index 0000000000000..da508cc95bfe1
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
@@ -0,0 +1,24 @@
+//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/80518


More information about the Mlir-commits mailing list