[llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)
Boian Petkantchin via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 6 18:31:51 PST 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/80518
>From e3b9f4a599d5485c12c9e8a2c9c11ac37100fc8a Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 26 Jan 2024 09:47:46 -0800
Subject: [PATCH] [mlir][mesh] Add spmdization pass
Add a pass that converts a function that has sharding annotations into SPMD
form.
---
.../Func/Extensions/MeshShardingExtensions.h | 23 ++
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 45 +++
.../Mesh/Interfaces/ShardingInterface.h | 11 +-
.../Mesh/Interfaces/ShardingInterface.td | 47 ++-
.../Mesh/Interfaces/ShardingInterfaceImpl.h | 125 ++++++++
.../mlir/Dialect/Mesh/Transforms/Passes.td | 58 ++++
.../Dialect/Mesh/Transforms/Spmdization.h | 24 +-
.../Dialect/Func/Extensions/AllExtensions.cpp | 2 +
.../Dialect/Func/Extensions/CMakeLists.txt | 15 +
.../Extensions/MeshShardingExtensions.cpp | 24 ++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 32 ++
.../Mesh/Interfaces/ShardingInterface.cpp | 65 +++-
.../Mesh/Transforms/ShardingPropagation.cpp | 6 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 299 +++++++++++++++---
.../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 30 +-
.../Dialect/Mesh/resharding-spmdization.mlir | 4 +-
mlir/test/Dialect/Mesh/spmdization.mlir | 131 ++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 8 +-
18 files changed, 848 insertions(+), 101 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
create mode 100644 mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
create mode 100644 mlir/test/Dialect/Mesh/spmdization.mlir
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
new file mode 100644
index 00000000000000..9b7abbca5d7622
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
+//===- MeshShardingExtensions.h - Func mesh sharding extensions -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+#define MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
+
+namespace mlir {
+class DialectRegistry;
+
+namespace func {
+
+/// Registers the mesh `ShardingInterface` external models for func dialect
+/// operations with `registry`.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_EXTENSIONS_MESHSHARDINGEXTENSIONS_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 8e5e0f541ba5ee..9154e6fd803102 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,11 +10,13 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/MathExtras.h"
namespace mlir {
namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
Partial getPartialTypeFromReduction(IteratorType iType);
+// Returns true if `attr` describes full replication, i.e. the same tensor is
+// replicated on all processes: there are no split axes and no partial axes.
+// `attr` must be non-null.
+inline bool isFullReplication(MeshShardingAttr attr) {
+ return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+}
+
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) {
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
}
+// Specialization for `ShardOp`: its mesh symbol lives inside its sharding
+// attribute (`getShardAttr().getMesh()`), so the mesh is looked up from there.
+template <>
+inline mesh::MeshOp
+getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+ return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
+ symbolTableCollection);
+}
+
// Get the number of processes that participate in each group
// induced by `meshAxes`.
template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,37 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
+// Get the size of a sharded dimension, i.e. the extent held by each of the
+// `shardCount` shards of a tensor dimension of size `dimSize`.
+// Returns `ShapedType::kDynamic` if either argument is dynamic.
+// `dimSize` is expected to be divisible by `shardCount`. This is only checked
+// by an assert; with assertions disabled the result is rounded up by
+// `ceilDiv`.
+inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
+ if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ assert(dimSize % shardCount == 0);
+ return ceilDiv(dimSize, shardCount);
+}
+
+// Get the size of an unsharded (gathered) dimension: the extent obtained by
+// gathering `shardCount` shards that each have size `dimSize`, i.e. their
+// product. Returns `ShapedType::kDynamic` if either argument is dynamic.
+inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
+ if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ return dimSize * shardCount;
+}
+
+// Return the sharded shape of `shape` according to sharding `sharding`,
+// i.e. the shape of the tensor held on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1 would
+// result in a shape for each shard of ?x2x?.
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
+ MeshShardingAttr sharding);
+
+// If `type` is a ranked tensor type, return its sharded counterpart.
+//
+// Otherwise return `type` unchanged; `sharding` must be null in that case.
+Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 3bef7e6babdec9..cc90ddd40a6222 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,11 +10,14 @@
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class Operation;
+class IRMapping;
+class SymbolTableCollection;
namespace mesh {
@@ -58,8 +61,14 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
} // namespace detail
-} // namespace mesh
+// Spmdize `op` by cloning it unchanged with `builder`; the clone's results are
+// recorded in `spmdizationMap`.
+// Assumes full replication on all ranked tensor arguments and results: each
+// ranked tensor value must have a non-null, fully replicated sharding and every
+// other value a null sharding (checked with asserts only).
+void spmdizeFullyReplicatedOperation(
+ Operation &op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable, OpBuilder &builder);
+} // namespace mesh
} // namespace mlir
/// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 21b6c8d4f599a8..4afb1c36a72f7b 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
return detail::defaultAddShardingAnnotations(
$_op.getOperation(), b, shardingOption);
}]
- >
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Convert self to SPMD form.
+ This method is used during the spmdization pass of a program fully
+ annotated with shardings.
+
+ The spmdization algorithm reads the surrounding sharding
+ annotations from the IR for each argument/result and prepares
+ `operandShardings` and `resultShardings`.
+ Values that are not ranked tensors do not have sharding annotations.
+ In this case their corresponding MeshShardingAttr is null.
+
+ For convenience it also prepares `spmdizedOperands`, although
+ they can be retrieved from the `spmdizationMap`.
+
+ The `spmdizationMap` contains a mapping from unsharded to
+ sharded/spmdized values that are constructed during the spmdization
+ pass. The interface implementation must populate `spmdizationMap`
+ with the mapping for this op's results.
+
+ `builder` is set to insert new operations at the appropriate point.
+ The implementation should not return the builder to the original
+ insertion point.
+ It should leave it as is after all insertions are done.
+
+ The default implementation does full replication: it clones the op
+ unchanged and always succeeds.
+ This assumes that all sharding annotations are for full replication.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"spmdize",
+ /*args=*/(ins
+ "ArrayRef<Value>": $spmdizedOperands,
+ "ArrayRef<MeshShardingAttr>": $operandShardings,
+ "ArrayRef<MeshShardingAttr>": $resultShardings,
+ "IRMapping&": $spmdizationMap,
+ "SymbolTableCollection &": $symbolTableCollection,
+ "OpBuilder &":$builder
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ spmdizeFullyReplicatedOperation(
+ *$_op.getOperation(), spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap, symbolTableCollection, builder);
+ return success();
+ }]>
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
new file mode 100644
index 00000000000000..8108386c2e0437
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -0,0 +1,125 @@
+//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+
+class Operation;
+class IRMapping;
+class SymbolTableCollection;
+
+namespace mesh {
+
+// Inserts, with `builder`, a clone of `op` whose ranked tensor results get
+// the sharded types described by `resultShardings`.
+// The clone's results are recorded in `spmdizationMap`, and its operands are
+// remapped through that map; `spmdizedOperands` and `operandShardings` are
+// not read directly.
+void spmdizeTriviallyShardableOperation(
+ Operation &op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable, OpBuilder &builder);
+
+// All ranked tensor argument and result dimensions have
+// independent parallel loop iterators.
+template <typename Op>
+struct IndependentParallelIteratorDomainShardingInterface
+ : public ShardingInterface::ExternalModel<
+ IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
+ // One parallel iterator per dimension of every ranked tensor operand,
+ // followed by one per dimension of every ranked tensor result.
+ SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
+ SmallVector<IteratorType> iterTypes;
+ for (Type t : operation->getOperandTypes()) {
+ populateIteratorTypes(t, iterTypes);
+ }
+ for (Type t : operation->getResultTypes()) {
+ populateIteratorTypes(t, iterTypes);
+ }
+ return iterTypes;
+ }
+
+ // Not implemented yet: currently returns no indexing maps.
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ // TODO: implement.
+ return SmallVector<AffineMap>();
+ }
+
+ // Clone the op with sharded result types (see
+ // `spmdizeTriviallyShardableOperation`). Always succeeds.
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTable, builder);
+ return success();
+ }
+
+private:
+ // Append one `IteratorType::Parallel` per dimension of `t` if it is a ranked
+ // tensor type; any other type contributes no iterators.
+ void populateIteratorTypes(Type t,
+ SmallVector<IteratorType> &iterTypes) const {
+ RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
+ if (!rankedTensorType) {
+ return;
+ }
+
+ iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
+ for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
+ iterTypes.push_back(IteratorType::Parallel);
+ }
+ }
+};
+
+// Sharding of elementwise operations like tensor addition and multiplication.
+template <typename ElemwiseOp>
+struct ElementwiseShardingInterface
+ : public ShardingInterface::ExternalModel<
+ ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
+ // One parallel iterator per dimension of the first operand when it is a
+ // ranked tensor, otherwise none. Requires the op to have an operand.
+ SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+ Value val = op->getOperand(0);
+ auto type = val.getType().dyn_cast<RankedTensorType>();
+ if (!type)
+ return {};
+ SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+ return types;
+ }
+
+ // An identity map of the first operand's rank for every operand and result.
+ // NOTE(review): presumably all operands/results share that rank; this is
+ // not checked here.
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ MLIRContext *ctx = op->getContext();
+ Value val = op->getOperand(0);
+ auto type = val.getType().dyn_cast<RankedTensorType>();
+ if (!type)
+ return {};
+ int64_t rank = type.getRank();
+ int64_t num = op->getNumOperands() + op->getNumResults();
+ SmallVector<AffineMap> maps(num,
+ AffineMap::getMultiDimIdentityMap(rank, ctx));
+ return maps;
+ }
+
+ // Clone the op with sharded result types (see
+ // `spmdizeTriviallyShardableOperation`). Always succeeds.
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTable, builder);
+ return success();
+ }
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index c09cf3e710d427..7fb6631574b410 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -29,4 +29,62 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+ let summary = "Partition a function into SPMD form.";
+ let description = [{
+ This pass fits in right after a pass that annotates the function with
+ shardings like the `ShardingPropagation` pass.
+ It operates on a fully annotated IR.
+
+ A fully annotated IR requires that all ranked tensor operands, results and
+ block arguments are annotated with the `mesh.shard` operation.
+
+ All direct descendant operations in the function must implement the
+ `ShardingInterface` interface or all their ranked tensor operands and
+ results must have full replication sharding.
+
+ The input IR must have sharding annotations such that each operation
+ that implements `ShardingInterface` can handle them during spmdization
+ with its `spmdize` method.
+ This can be achieved with the `ShardingPropagation` pass.
+
+ If the function has multiple terminating blocks,
+ it is the responsibility of the one who annotates the function with
+ shardings to make sure that all returns are consistent, that is,
+ have the same sharding.
+
+ Example:
+ ```mlir
+ mesh.mesh @mesh_1d(shape = 2)
+
+ func.func @f(
+ %arg0: tensor<2xi8>
+ ) -> tensor<2xi8> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ return %4 : tensor<2xi8>
+ }
+ ```
+ Spmdizing the above would result in:
+ * Performing the element-wise `abs` operation on each device.
+ * Resharding to full replication with an all-gather.
+
+ ```mlir
+ mesh.mesh @mesh_1d(shape = 2)
+
+ func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
+ %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
+ %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ return %1 : tensor<2xi8>
+ }
+ ```
+ }];
+ let dependentDialects = [
+ "mesh::MeshDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index 7cb992aac019b3..f847ce30a1b40f 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -16,16 +16,32 @@
namespace mlir {
namespace mesh {
-// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
- MeshShardingAttr sharding);
-
// Insert resharding spmdization of the value `sourceShardValue`
// from sharding `source` to sharding `target`.
// `sourceShardValue` is the already sharded value according to `source`.
+//
+// Example
+//
+// ```mlir
+// mesh.mesh @mesh_1d(shape = 2)
+// ...
+// %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+// %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+// ```
+//
+// Will result in
+//
+// ```mlir
+// %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+// tensor<1xi8> -> tensor<2xi8>
+// ```
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue);
+
+// Same as the overload above, but looks up the mesh of `source` through
+// `symbolTableCollection`. `source` and `target` must refer to the same mesh.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+ ShardOp target,
+ TypedValue<ShapedType> sourceShardValue,
+ SymbolTableCollection &symbolTableCollection);
void reshardingRegisterDependentDialects(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index b3a3a37663a215..eb6b59bb00f1bc 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,9 +8,11 @@
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
using namespace mlir;
+// Register all func dialect extensions: the inliner extension and the mesh
+// sharding interface external models.
void mlir::func::registerAllExtensions(DialectRegistry &registry) {
registerInlinerExtension(registry);
+ registerShardingInterfaceExternalModels(registry);
}
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 8712b1583719b8..47363f48f95cc4 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,6 +1,7 @@
+# All sources of the func dialect extension libraries in this directory.
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
InlinerExtension.cpp
+ MeshShardingExtensions.cpp
)
add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -16,6 +17,19 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
MLIRFuncDialect
)
+# Mesh `ShardingInterface` external models for func dialect operations.
+add_mlir_extension_library(MLIRFuncMeshShardingExtensions
+ MeshShardingExtensions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRIR
+ MLIRShardingInterface
+ )
+
add_mlir_extension_library(MLIRFuncAllExtensions
AllExtensions.cpp
@@ -24,4 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
LINK_LIBS PUBLIC
MLIRFuncInlinerExtension
+ MLIRFuncMeshShardingExtensions
)
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
new file mode 100644
index 00000000000000..da508cc95bfe13
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
@@ -0,0 +1,24 @@
+//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir::func {
+
+// Attach the mesh `ShardingInterface` external model to `ReturnOp` once the
+// func dialect is loaded: every dimension of its ranked tensor operands is an
+// independent parallel loop iterator, and spmdization clones the op with
+// sharded types.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) {
+    ReturnOp::attachInterface<
+        mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
+        *ctx);
+  });
+}
+
+} // namespace mlir::func
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 7c6a5e02ed5830..a65b8f2e5a2376 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -175,6 +175,38 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
}
}
+// Write into `outShape` the per-device shape of a tensor of shape `inShape`
+// on a mesh of shape `meshShape`. Tensor axis `i` is divided by the number of
+// processes in mesh axes `splitAxes[i]`; axes beyond `splitAxes` are copied
+// unchanged. `outShape` must already be sized like `inShape`.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+ typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+ const SplitAxes &splitAxes, OutShape &outShape) {
+ std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+ llvm::adl_begin(outShape));
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+ outShape[tensorAxis] = shardDimension(
+ inShape[tensorAxis],
+ collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+ }
+}
+
+// Compute the per-device shape of `shape` under `sharding` on `mesh` and
+// return `shape` cloned with that shape.
+ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
+ MeshShardingAttr sharding) {
+ using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
+ SmallVector<Dim> resShapeArr(shape.getShape().size());
+ shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+ resShapeArr);
+ return shape.clone(resShapeArr);
+}
+
+// Ranked tensor types are sharded with `shardShapedType`; any other type is
+// returned unchanged and must come with a null `sharding` (asserted).
+Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
+ RankedTensorType rankedTensorType = type.dyn_cast<RankedTensorType>();
+ if (rankedTensorType) {
+ return shardShapedType(rankedTensorType, mesh, sharding);
+ }
+
+ assert(!sharding);
+ return type;
+}
+
//===----------------------------------------------------------------------===//
// mesh.mesh op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 5dc91ff1c02d20..a0a573f65a3467 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -7,14 +7,16 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
-#include <algorithm>
#include <utility>
#define DEBUG_TYPE "sharding-interface"
@@ -391,8 +393,6 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//
-namespace {
-
// To add a `mesh.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
@@ -493,8 +493,6 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
return success();
}
-} // namespace
-
LogicalResult mesh::detail::defaultAddShardingAnnotations(
Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
@@ -519,3 +517,58 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
return success();
}
+
+// A ranked tensor value is compatible only with a non-null, fully replicated
+// sharding; any other value must have a null sharding.
+static bool
+isValueCompatibleWithFullReplicationSharding(Value value,
+ MeshShardingAttr sharding) {
+ if (value.getType().isa<RankedTensorType>()) {
+ return sharding && isFullReplication(sharding);
+ }
+
+ return !sharding;
+}
+
+// Returns true if `values` and `shardings` have the same length and each value
+// is compatible with its sharding as defined by
+// `isValueCompatibleWithFullReplicationSharding`.
+// (Template parameters are not named `ValueRange` to avoid shadowing
+// `mlir::ValueRange`.)
+template <typename ValueRangeT, typename MeshShardingAttrRangeT>
+static bool areValuesCompatibleWithFullReplicationShardings(
+    ValueRangeT &&values, MeshShardingAttrRangeT &&shardings) {
+  if (std::size(values) != std::size(shardings))
+    return false;
+  return llvm::all_of(llvm::zip(std::forward<ValueRangeT>(values),
+                                std::forward<MeshShardingAttrRangeT>(shardings)),
+                      [](auto valueAndSharding) {
+                        return isValueCompatibleWithFullReplicationSharding(
+                            std::get<0>(valueAndSharding),
+                            std::get<1>(valueAndSharding));
+                      });
+}
+
+// Full replication: every device runs `op` unchanged, so spmdization is a
+// plain clone. The asserts only verify (in debug builds) that the operand and
+// result annotations really are full replication.
+void mesh::spmdizeFullyReplicatedOperation(
+ Operation &op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable, OpBuilder &builder) {
+ assert(spmdizedOperands.size() == operandShardings.size());
+ assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
+ operandShardings));
+ assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
+ resultShardings));
+ // `clone` will populate the mapping of old to new results.
+ builder.clone(op, spmdizationMap);
+}
+
+// Clone `op` and give each ranked tensor result the type sharded according to
+// its entry in `resultShardings`. The clone's results are recorded in
+// `spmdizationMap`.
+void mesh::spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder) {
+  // `clone` will populate the mapping of old to new results.
+  Operation *newOp = builder.clone(op, spmdizationMap);
+  // Set the result types to the sharded counterparts. Results that are not
+  // ranked tensors have a null sharding: they keep their type, and their
+  // (null) sharding must not be used to look up a mesh.
+  for (auto [newResult, sharding] :
+       llvm::zip(newOp->getResults(), resultShardings)) {
+    if (!sharding)
+      continue;
+    MeshOp mesh = getMesh(&op, sharding.getMesh(), symbolTable);
+    newResult.setType(shardType(newResult.getType(), mesh, sharding));
+  }
+}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index f3cd12f38879d8..9f2647b21cbfc8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -29,8 +29,6 @@ namespace mesh {
using namespace mlir;
using namespace mlir::mesh;
-namespace {
-
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
@@ -83,7 +81,7 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
// `getShardingOption` method. If the inferred sharding option is not empty, add
// a `mesh.shard` operation for all remaining operands and results that do not
// have sharding annotations.
-LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
return success();
@@ -207,5 +205,3 @@ struct ShardingPropagation
return signalPassFailure();
}
};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 5554edac4d2f63..b649157a9e46de 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -7,79 +7,40 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/MathExtras.h"
-#include "llvm/ADT/ADL.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include <algorithm>
+#include "llvm/Support/Casting.h"
#include <iterator>
-#include <numeric>
#include <optional>
#include <tuple>
#include <type_traits>
-namespace mlir {
-namespace mesh {
-
-int64_t shardDimension(int64_t dim, int64_t shardCount) {
- if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
- return ShapedType::kDynamic;
-
- assert(dim % shardCount == 0);
- return ceilDiv(dim, shardCount);
-}
-
-int64_t unshardDimension(int64_t dim, int64_t shardCount) {
- if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
- return ShapedType::kDynamic;
-
- return dim * shardCount;
-}
-
-// Compute the shape for the tensor on each device in the mesh.
-// Example:
-// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
-// would result in a shape for each shard of ?x2x?.
-template <typename InShape, typename MeshShape, typename SplitAxes,
- typename OutShape>
-static void shardShape(const InShape &inShape, const MeshShape &meshShape,
- const SplitAxes &splitAxes, OutShape &outShape) {
- std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
- llvm::adl_begin(outShape));
- for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
- outShape[tensorAxis] = shardDimension(
- inShape[tensorAxis],
- collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
- }
-}
-
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
- MeshShardingAttr sharding) {
- using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
- SmallVector<Dim> resShapeArr(shape.getShape().size());
- shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
- resShapeArr);
- return shape.clone(resShapeArr);
-}
+namespace mlir::mesh {
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
@@ -362,7 +323,7 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis(
ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
targetShape[splitTensorAxis] =
- unshardDimension(targetShape[splitTensorAxis], splitCount);
+ gatherDimension(targetShape[splitTensorAxis], splitCount);
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
@@ -493,7 +454,7 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
int64_t targetTensorAxis) {
SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
targetShape[sourceTensorAxis] =
- unshardDimension(targetShape[sourceTensorAxis], splitCount);
+ gatherDimension(targetShape[sourceTensorAxis], splitCount);
targetShape[targetTensorAxis] =
shardDimension(targetShape[targetTensorAxis], splitCount);
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
@@ -616,10 +577,244 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
}
+// Convenience overload: resolve the mesh of `source` through
+// `symbolTableCollection` and forward to the overload taking a `MeshOp`.
+// `source` and `target` must use the same mesh (checked by an assert only).
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+ ShardOp target,
+ TypedValue<ShapedType> sourceShardValue,
+ SymbolTableCollection &symbolTableCollection) {
+ MeshOp srcMesh = getMesh(source, symbolTableCollection);
+ assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
+ return reshard(builder, srcMesh, source, target, sourceShardValue);
+}
+
void reshardingRegisterDependentDialects(DialectRegistry ®istry) {
registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
cf::ControlFlowDialect>();
}
-} // namespace mesh
-} // namespace mlir
+// Pull in the generated `impl::SpmdizationBase` used by the pass below.
+#define GEN_PASS_DEF_SPMDIZATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+// Get the types of block arguments for an spmdized block.
+// Reads the sharding annotations of the arguments to deduce the sharded types.
+// Types that are not ranked tensors are left unchanged.
+// Every ranked tensor argument is expected to have exactly one user, a ShardOp
+// carrying its sharding.
+static SmallVector<Type>
+shardedBlockArgumentTypes(Block &block,
+                          SymbolTableCollection &symbolTableCollection) {
+  SmallVector<Type> res;
+  res.reserve(block.getNumArguments());
+  llvm::transform(block.getArguments(), std::back_inserter(res),
+                  [&symbolTableCollection](BlockArgument arg) -> Type {
+                    auto rankedTensorArg =
+                        arg.dyn_cast<TypedValue<RankedTensorType>>();
+                    if (!rankedTensorArg) {
+                      return arg.getType();
+                    }
+
+                    assert(rankedTensorArg.hasOneUse());
+                    Operation *useOp = *rankedTensorArg.getUsers().begin();
+                    ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
+                    assert(shardOp);
+                    MeshOp mesh = getMesh(shardOp, symbolTableCollection);
+                    // The explicit `-> Type` return makes the ShapedType
+                    // result convert implicitly.
+                    return shardShapedType(rankedTensorArg.getType(), mesh,
+                                           shardOp.getShardAttr());
+                  });
+  return res;
+}
+
+// Spmdize a single operation given its already spmdized operands and the
+// sharding annotations of its operands and results.
+// On success every result of `op` has an entry in `spmdizationMap`.
+static LogicalResult spmdizeOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+  if (auto shardingInterface = llvm::dyn_cast<ShardingInterface>(op)) {
+    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
+                                         resultShardings, spmdizationMap,
+                                         symbolTableCollection, builder))) {
+      return failure();
+    }
+  } else {
+    // Without a sharding interface we are conservative and assume that the
+    // op must be fully replicated on all devices.
+    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
+                                    resultShardings, spmdizationMap,
+                                    symbolTableCollection, builder);
+  }
+
+  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
+    return spmdizationMap.contains(result);
+  }));
+
+  return success();
+}
+
+// Collect the sharding annotation of every operand of `op`.
+// Operands that are not ranked tensors are not required to have an annotation
+// and get a null MeshShardingAttr. A ranked tensor operand must be produced by
+// a ShardOp that annotates its users.
+static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
+  SmallVector<MeshShardingAttr> shardings;
+  shardings.reserve(op.getNumOperands());
+  for (Value operand : op.getOperands()) {
+    if (!operand.dyn_cast<TypedValue<RankedTensorType>>()) {
+      shardings.push_back(MeshShardingAttr());
+      continue;
+    }
+
+    Operation *producer = operand.getDefiningOp();
+    assert(producer);
+    auto annotation = llvm::cast<ShardOp>(producer);
+    assert(annotation.getAnnotateForUsers());
+    shardings.push_back(annotation.getShard());
+  }
+  return shardings;
+}
+
+// Collect the sharding annotation of every result of `op`.
+// Results that are not ranked tensors are not required to have an annotation
+// and get a null MeshShardingAttr. A ranked tensor result must have exactly
+// one user: a ShardOp that annotates the result itself, not its users.
+static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
+  SmallVector<MeshShardingAttr> shardings;
+  shardings.reserve(op.getNumResults());
+  for (OpResult result : op.getResults()) {
+    if (!result.dyn_cast<TypedValue<RankedTensorType>>()) {
+      shardings.push_back(MeshShardingAttr());
+      continue;
+    }
+
+    assert(result.hasOneUse());
+    auto annotation = llvm::cast<ShardOp>(*result.getUsers().begin());
+    assert(!annotation.getAnnotateForUsers());
+    shardings.push_back(annotation.getShard());
+  }
+  return shardings;
+}
+
+// Spmdize `op`, recording the spmdized counterparts of its results in
+// `spmdizationMap`.
+// - A ShardOp that does not annotate its users is skipped; it is consumed
+//   when the user-side ShardOp that reads it is processed.
+// - A ShardOp that annotates its users inserts resharding from the
+//   producer-side annotation to its own.
+// - Any other op is spmdized from its already-mapped operands.
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+                 SymbolTableCollection &symbolTableCollection,
+                 OpBuilder &builder) {
+  if (ShardOp shardOp = llvm::dyn_cast<ShardOp>(op)) {
+    if (!shardOp.getAnnotateForUsers()) {
+      return success();
+    }
+
+    // Insert resharding.
+    Operation *producer = shardOp.getOperand().getDefiningOp();
+    auto srcShardOp = llvm::cast<ShardOp>(producer);
+    assert(!srcShardOp.getAnnotateForUsers());
+    auto srcSpmdValue = spmdizationMap.lookup(srcShardOp.getOperand())
+                            .cast<TypedValue<ShapedType>>();
+    Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+                                    symbolTableCollection);
+    assert(!spmdizationMap.contains(shardOp.getResult()));
+    spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+    return success();
+  }
+
+  SmallVector<Value> spmdizedOperands;
+  spmdizedOperands.reserve(op.getNumOperands());
+  for (Value operand : op.getOperands()) {
+    assert(spmdizationMap.contains(operand));
+    spmdizedOperands.push_back(spmdizationMap.lookup(operand));
+  }
+  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
+                          getResultShardings(op), spmdizationMap,
+                          symbolTableCollection, builder);
+}
+
+// Spmdize `block` into a new block created in the same region.
+// The new block's arguments get the sharded types derived from the arguments'
+// sharding annotations, and every operation of `block` is spmdized into the
+// new block in order. The builder's insertion point is restored on return.
+static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
+                                  SymbolTableCollection &symbolTableCollection,
+                                  OpBuilder &builder) {
+  // Create the guard before `createBlock`, which moves the insertion point
+  // into the new block. That way the caller's position is what gets restored.
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  SmallVector<Location> argLocations;
+  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
+                  [](BlockArgument arg) { return arg.getLoc(); });
+  Block *newBlock = builder.createBlock(
+      block.getParent(), {},
+      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
+  for (auto [unshardedBlockArg, spmdizedBlockArg] :
+       llvm::zip(block.getArguments(), newBlock->getArguments())) {
+    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+  }
+
+  builder.setInsertionPointToEnd(newBlock);
+  for (Operation &op : block.getOperations()) {
+    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
+                                builder))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+// Spmdize the body of `op`: every original block is spmdized into a new block
+// and the original blocks are then erased. The function type is updated from
+// the new entry block's argument types and the operand types of a return op.
+static LogicalResult
+spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+              SymbolTableCollection &symbolTableCollection) {
+  // A declaration has no body to spmdize (and no return op to derive the
+  // result types from), so leave it untouched.
+  if (op.isExternal()) {
+    return success();
+  }
+
+  OpBuilder builder(op.getFunctionBody());
+
+  // Snapshot the original blocks to not mess up the iteration when adding new
+  // blocks.
+  SmallVector<Block *> originalBlocks;
+  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
+                  [](Block &b) { return &b; });
+
+  for (Block *block : originalBlocks) {
+    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
+                            builder))) {
+      return failure();
+    }
+  }
+
+  // First drop the operand references of all original operations. Values
+  // defined in one original block may be used in another; dropping first
+  // ensures they have no remaining uses when their block is erased.
+  for (Block *block : originalBlocks) {
+    block->dropAllReferences();
+  }
+  for (Block *block : originalBlocks) {
+    block->erase();
+  }
+
+  // Find a return op and change the function results signature to its operands
+  // signature. Blocks may end in other terminators (e.g. branches), so use a
+  // checked cast rather than `cast`, which would assert on them.
+  func::ReturnOp returnOp;
+  for (Block &block : op.getBody()) {
+    if (block.empty()) {
+      continue;
+    }
+
+    returnOp = llvm::dyn_cast<func::ReturnOp>(block.back());
+    if (returnOp) {
+      break;
+    }
+  }
+  assert(returnOp);
+  op.setFunctionType(FunctionType::get(op->getContext(),
+                                       op.getBody().front().getArgumentTypes(),
+                                       returnOp->getOperandTypes()));
+
+  return success();
+}
+
+namespace {
+
+/// Pass that spmdizes the function it runs on via `spmdizeFuncOp`.
+struct Spmdization : public impl::SpmdizationBase<Spmdization> {
+  void runOnOperation() override {
+    IRMapping spmdizationMap;
+    SymbolTableCollection symbolTableCollection;
+    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
+                             symbolTableCollection))) {
+      return signalPassFailure();
+    }
+  }
+
+  /// Resharding may create ops from the dialects it registers, so they must
+  /// be loaded before the pass runs.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    reshardingRegisterDependentDialects(registry);
+    registry.insert<mesh::MeshDialect>();
+  }
+};
+
+} // namespace
+
+} // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index dace86533c0e23..fefbca8c2e091a 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
@@ -23,33 +24,6 @@ using namespace mlir::mesh;
namespace {
-template <typename ElemwiseOp>
-struct ElemwiseSharding
- : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
- ElemwiseOp> {
- SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
- Value val = op->getOperand(0);
- auto type = val.getType().dyn_cast<RankedTensorType>();
- if (!type)
- return {};
- SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
- return types;
- }
-
- SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
- MLIRContext *ctx = op->getContext();
- Value val = op->getOperand(0);
- auto type = val.getType().dyn_cast<RankedTensorType>();
- if (!type)
- return {};
- int64_t rank = type.getRank();
- int64_t num = op->getNumOperands() + op->getNumResults();
- SmallVector<AffineMap> maps(num,
- AffineMap::getMultiDimIdentityMap(rank, ctx));
- return maps;
- }
-};
-
// loop types: [parallel, parallel, parallel, reduction_sum]
// indexing maps:
// (d0, d1, d2, d3) -> (d0, d1, d3)
@@ -83,7 +57,7 @@ struct MatMulOpSharding
+/// Attach the shared ElementwiseShardingInterface external model to `OpType`.
template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
- OpType::template attachInterface<ElemwiseSharding<OpType>>(*ctx);
+ OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
}
/// Variadic helper function.
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index cb98d31dd6692d..f1d686135c28eb 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -141,8 +141,8 @@ func.func @unshard_static_axis_on_dynamic_mesh_axis(
return %1 : tensor<10x14xf32>
}
-// CHECK-LABEL: func @partial_axis
-func.func @partial_axis(
+// CHECK-LABEL: func @partial_axis_to_full_replication
+func.func @partial_axis_to_full_replication(
// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
%arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
new file mode 100644
index 00000000000000..9993c1518e9eae
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+
+// 1-D mesh of 2 devices used by the sharded tests in this file.
+mesh.mesh @mesh_1d(shape = 2)
+
+// Resharding from full replication to full replication is a no-op: the
+// function argument is returned directly.
+// CHECK-LABEL: func @full_replication
+func.func @full_replication(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[ARG]] : tensor<2xi8>
+ return %1 : tensor<2xi8>
+}
+
+// Moving the split over mesh axis 0 from tensor axis 0 to tensor axis 1
+// lowers to a mesh.all_to_all whose result is returned.
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
+ %arg0: tensor<2x2xi8>
+// CHECK-SAME: -> tensor<2x1xi8> {
+) -> tensor<2x2xi8> {
+ // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
+ // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2x2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<2x2xi8>
+ // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
+ return %1 : tensor<2x2xi8>
+}
+
+// Values that are not ranked tensors need no sharding annotations; the op
+// and the function signature are kept unchanged.
+// CHECK-LABEL: func @non_tensor_value
+func.func @non_tensor_value(
+ // CHECK-SAME: %[[ARG:.*]]: i8
+ %arg0: i8
+// CHECK-SAME: -> i8 {
+) -> i8 {
+ // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
+ %0 = arith.addi %arg0, %arg0 : i8
+ // CHECK: return %[[RES]] : i8
+ return %0 : i8
+}
+
+// Operand and result are both split along tensor axis 0, so tosa.abs is
+// applied directly to the local 1-element shard.
+// CHECK-LABEL: func @unary_elementwise
+func.func @unary_elementwise(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %4 = mesh.shard %3 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %4 : tensor<2xi8>
+}
+
+// Resharding around an elementwise op:
+// full replication -> split along tensor axis 0 (extract_slice) -> abs
+// -> full replication (all_gather).
+// CHECK-LABEL: func @unary_elementwise_with_resharding
+func.func @unary_elementwise_with_resharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+ // We don't care about the whole resharding IR, just that it happens.
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
+ // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
+ // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<2xi8>
+ return %4 : tensor<2xi8>
+}
+
+// Both operands and the result are split along tensor axis 0, so tosa.add is
+// applied directly to the local shards.
+// CHECK-LABEL: func @binary_elementwise
+func.func @binary_elementwise(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
+ %arg0: tensor<2xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
+ %arg1: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %arg0_sharded = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %op_arg0 = mesh.shard %arg0_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %arg1_sharded = mesh.shard %arg1 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %op_arg1 = mesh.shard %arg1_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
+ %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
+ %op_res_sharded = mesh.shard %op_res to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %res = mesh.shard %op_res_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %res : tensor<2xi8>
+}
+
+// Chain of elementwise ops with resharding between them:
+// replicated -> split (reshard) -> abs -> replicated (reshard) -> abs
+// -> split (reshard).
+// CHECK-LABEL: func @multiple_chained_ops
+func.func @multiple_chained_ops(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ // We don't care about the whole resharding IR, just that it happens.
+ // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
+ // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+ %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
+ // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
+ %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1]
+ // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
+ %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[RESHARD3]] : tensor<1xi8>
+ return %7 : tensor<2xi8>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1d79587512539f..9f2c1f1120066b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3253,7 +3253,10 @@ gentbl_cc_library(
cc_library(
name = "MeshShardingInterface",
srcs = ["lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp"],
- hdrs = ["include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"],
+ hdrs = [
+ "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h",
+ "include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h",
+ ],
includes = ["include"],
deps = [
":DialectUtils",
@@ -4549,8 +4552,9 @@ cc_library(
deps = [
":ControlFlowDialect",
":FuncDialect",
- ":IR",
":InferTypeOpInterface",
+ ":IR",
+ ":MeshShardingInterface",
],
)
More information about the llvm-commits
mailing list