[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)
llvmlistbot at llvm.org
Fri Feb 2 16:25:07 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-tosa
Author: Boian Petkantchin (sogartar)
Changes:
Add a pass that converts a function that has sharding annotations into SPMD form.
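To try the patch out, here is a minimal C++ sketch of wiring the new func extension into a context. Only `registerShardingInterfaceExternalModels` comes from this patch; the rest is standard MLIR setup, and the `main` scaffolding is purely illustrative:

```c++
#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::DialectRegistry registry;
  // Attach the sharding interface external models for the func dialect,
  // as declared in the new MeshShardingExtensions.h header.
  mlir::func::registerShardingInterfaceExternalModels(registry);
  mlir::MLIRContext context(registry);
  // A fully annotated module can then be spmdized by running the new
  // "mesh-spmdization" pass, e.g. via `mlir-opt --mesh-spmdization`.
  return 0;
}
```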
---
Patch is 49.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80518.diff
17 Files Affected:
- (added) mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h (+23)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+43)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+10-1)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+46-1)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+125)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+75)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+20-4)
- (modified) mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp (+2)
- (modified) mlir/lib/Dialect/Func/Extensions/CMakeLists.txt (+15)
- (added) mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp (+24)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+32)
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+59-6)
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+1-5)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+241-50)
- (modified) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+2-28)
- (added) mlir/test/Dialect/Mesh/spmdization.mlir (+131)
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+6-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
new file mode 100644
index 0000000000000..9b7abbca5d762
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
+//===- MeshShardingExtensions.h - -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+#define MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace func {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 8e5e0f541ba5e..a819157fef147 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,11 +10,13 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/MathExtras.h"
namespace mlir {
namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
Partial getPartialTypeFromReduction(IteratorType iType);
+// Returns true if the same tensor is replicated on all processes.
+inline bool isFullReplication(MeshShardingAttr attr) {
+ return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+}
+
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) {
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
}
+template <>
+inline mesh::MeshOp
+getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+ return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
+ symbolTableCollection);
+}
+
// Get the number of processes that participate in each group
// induced by `meshAxes`.
template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,35 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
+inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
+ if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ assert(dim % shardCount == 0);
+ return ceilDiv(dim, shardCount);
+}
+
+inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+ if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ return dim * shardCount;
+}
+
+// Return the sharded shape `shape` according to sharding `sharding`.
+// This is the shape of the tensor shard on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
+// result in a shape for each shard of ?x2x?.
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
+ MeshShardingAttr sharding);
+
+// If `type` is a ranked tensor type, return its sharded counterpart.
+//
+// Otherwise return `type` unchanged.
+// `sharding` in that case must be null.
+Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 3bef7e6babdec..cc90ddd40a622 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,11 +10,14 @@
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class Operation;
+class IRMapping;
+class SymbolTableCollection;
namespace mesh {
@@ -58,8 +61,14 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
} // namespace detail
-} // namespace mesh
+// Assumes full replication on all ranked tensor arguments and results.
+void spmdizeFullyReplicatedOperation(
+ Operation &op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable, OpBuilder &builder);
+} // namespace mesh
} // namespace mlir
/// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 21b6c8d4f599a..4afb1c36a72f7 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
return detail::defaultAddShardingAnnotations(
$_op.getOperation(), b, shardingOption);
}]
- >
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Convert self to SPMD form.
+ This method is used during the spmdization pass on a program fully
+ annotated with shardings.
+
+ The spmdization algorithm reads the surrounding sharding
+ annotations from the IR for each argument/result and prepares
+ `operandShardings` and `resultShardings`.
+ Values that are not ranked tensors do not have sharding annotations;
+ their corresponding MeshShardingAttr is null in that case.
+
+ For convenience it will also prepare `spmdizedOperands`, although
+ they can be retrieved from the `spmdizationMap`.
+
+ The `spmdizationMap` contains a mapping from unsharded to
+ sharded/spmdized values that are constructed during the spmdization
+ pass. The interface implementation must populate `spmdizationMap`
+ with the mapping for this op's results.
+
+ `builder` is set to insert new operations at the appropriate point.
+ The implementation should not restore the builder to its original
+ insertion point; it should leave it as is after all insertions are
+ done.
+
+ The default implementation performs full replication and assumes
+ that all sharding annotations denote full replication.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"spmdize",
+ /*args=*/(ins
+ "ArrayRef<Value>": $spmdizedOperands,
+ "ArrayRef<MeshShardingAttr>": $operandShardings,
+ "ArrayRef<MeshShardingAttr>": $resultShardings,
+ "IRMapping&": $spmdizationMap,
+ "SymbolTableCollection &": $symbolTableCollection,
+ "OpBuilder &":$builder
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ spmdizeFullyReplicatedOperation(
+ *$_op.getOperation(), spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap, symbolTableCollection, builder);
+ return success();
+ }]>
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..8108386c2e043
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -0,0 +1,125 @@
+//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+
+class Operation;
+class IRMapping;
+class SymbolTableCollection;
+
+namespace mesh {
+
+// Inserts a clone of the operation that has all ranked tensor
+// arguments/results sharded.
+void spmdizeTriviallyShardableOperation(
+ Operation &op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable, OpBuilder &builder);
+
+// All ranked tensor argument and result dimensions have
+// independent parallel loop iterators.
+template <typename Op>
+struct IndependentParallelIteratorDomainShardingInterface
+ : public ShardingInterface::ExternalModel<
+ IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
+ SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
+ SmallVector<IteratorType> iterTypes;
+ for (Type t : operation->getOperandTypes()) {
+ populateIteratorTypes(t, iterTypes);
+ }
+ for (Type t : operation->getResultTypes()) {
+ populateIteratorTypes(t, iterTypes);
+ }
+ return iterTypes;
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ // TODO: implement.
+ return SmallVector<AffineMap>();
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTable, builder);
+ return success();
+ }
+
+private:
+ void populateIteratorTypes(Type t,
+ SmallVector<IteratorType> &iterTypes) const {
+ RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
+ if (!rankedTensorType) {
+ return;
+ }
+
+ iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
+ for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
+ iterTypes.push_back(IteratorType::Parallel);
+ }
+ }
+};
+
+// Sharding of elementwise operations like tensor addition and multiplication.
+template <typename ElemwiseOp>
+struct ElementwiseShardingInterface
+ : public ShardingInterface::ExternalModel<
+ ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
+ SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+ Value val = op->getOperand(0);
+ auto type = val.getType().dyn_cast<RankedTensorType>();
+ if (!type)
+ return {};
+ SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+ return types;
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ MLIRContext *ctx = op->getContext();
+ Value val = op->getOperand(0);
+ auto type = val.getType().dyn_cast<RankedTensorType>();
+ if (!type)
+ return {};
+ int64_t rank = type.getRank();
+ int64_t num = op->getNumOperands() + op->getNumResults();
+ SmallVector<AffineMap> maps(num,
+ AffineMap::getMultiDimIdentityMap(rank, ctx));
+ return maps;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTable, builder);
+ return success();
+ }
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index c09cf3e710d42..3758b617bc14e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -29,4 +29,79 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+ let summary = "Partition a function into SPMD form.";
+ let description = [{
+ This pass fits in right after a pass that annotates the function with
+ shardings, such as the `ShardingPropagation` pass.
+ It operates on a fully annotated IR.
+
+ A fully annotated IR requires that all ranked tensor operands, results and
+ block arguments are annotated with the `mesh.shard` operation.
+
+ All direct descendant operations in the function must implement the
+ `ShardingInterface` interface or all their ranked tensor operands and
+ results must have full replication sharding.
+
+ The input IR must have sharding annotations such that each operation
+ that implements `ShardingInterface` can handle them during spmdization
+ with its `spmdize` method.
+ This can be achieved with the `ShardingPropagation` pass.
+
+ If the function has multiple terminating blocks,
+ it is the responsibility of whoever annotates the function with
+ shardings to make sure that all returns are consistent, that is,
+ have the same sharding.
+
+ Example:
+ ```mlir
+ mesh.mesh @mesh_1d(shape = 2)
+
+ func.func @f(
+ %arg0: tensor<2xi8>
+ ) -> tensor<2xi8> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ return %4 : tensor<2xi8>
+ }
+ ```
+ Spmdizing the above would result in
+ * resharding the fully replicated input by splitting it along the only
+ tensor axis,
+ * performing the element-wise `abs` operation on each device, and
+ * resharding back to full replication with an all-gather.
+
+ ```mlir
+ mesh.mesh @mesh_1d(shape = 2)
+ func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
+ // Reshard [[]] -> [[0]]
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %0 = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ %1 = mesh.mesh_shape @mesh_1d axes = [0] : index
+ %2 = arith.remui %c2, %1 : index
+ %3 = arith.cmpi eq, %2, %c0 : index
+ cf.assert %3, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
+ %4 = arith.divui %c2, %1 : index
+ %5 = arith.muli %4, %0 : index
+ %extracted_slice = tensor.extract_slice %arg0[%5] [1] [1] : tensor<2xi8> to tensor<1xi8>
+
+ %6 = tosa.abs %extracted_slice : (tensor<1xi8>) -> tensor<1xi8>
+
+ // Reshard [[0]] -> [[]]
+ %7 = mesh.all_gather %6 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+
+ return %7 : tensor<2xi8>
+ }
+ ```
+ }];
+ let dependentDialects = [
+ "mesh::MeshDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index 7cb992aac019b..f847ce30a1b40 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -16,16 +16,32 @@
namespace mlir {
namespace mesh {
-// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
- MeshShardingAttr sharding);
-
// Insert resharding spmdization of the value `sourceShardValue`
// from sharding `source` to sharding `target`.
// `sourceShardValue` is the already sharded value according to `source`.
+//
+// Example
+//
+// ```mlir
+// mesh.mesh @mesh_1d(shape = 2)
+// ...
+// %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+// %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+// ```
+//
+// Will result in
+//
+// ```mlir
+// %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+// tensor<1xi8> -> tensor<2xi8>
+// ```
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue);
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+ ShardOp target,
+ TypedValue<ShapedType> sourceShardValue,
+ SymbolTableCollection &symbolTableCollection);
void reshardingRegisterDependentDialects(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index b3a3a37663a21..eb6b59bb00f1b 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,9 +8,11 @@
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
using namespace mlir;
void mlir::func::registerAllExtensions(DialectRegistry &registry) {
registerInlinerExtension(registry);
+ registerShardingInterfaceExternalModels(registry);
}
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 8712b1583719b..47363f48f95cc 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
InlinerExtension.cpp
+ MeshShardingExtensions.cpp
)
add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -16,6 +17,19 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
MLIRFuncDialect
)
+add_mlir_extension_library(MLIRFuncMeshShardingExtensions
+ MeshShardingExtensions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRIR
+ MLIRShardingInterface
+ )
+
add_mlir_extension_library(MLIRFuncAllExtensions
AllExtensions.cpp
@@ -24,4 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
LINK_LIBS PUBLIC
MLIRFuncInlinerExtension
+ MLIRFuncMeshShardingExtensions
)
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
new file mode 100644
index 0000000000000..da508cc95bfe1
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
@@ -0,0 +1,24 @@
+//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------...
[truncated]
``````````
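The TOSA portion of the diff is truncated above, so the exact registration code is not visible here. Based on the new reusable models in ShardingInterfaceImpl.h, attaching one would look roughly like the sketch below; the helper name and the choice of `tosa.abs` are illustrative, not necessarily what the patch registers:

```c++
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

// Attaches the reusable elementwise sharding model to tosa.abs so the
// spmdization pass can shard it without op-specific code. Which ops the
// patch actually covers is not visible in the truncated diff.
void registerTosaAbsShardingModel(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tosa::TosaDialect *dialect) {
    tosa::AbsOp::attachInterface<
        mesh::ElementwiseShardingInterface<tosa::AbsOp>>(*ctx);
  });
}
```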
https://github.com/llvm/llvm-project/pull/80518