[mlir] [llvm] [mlir][mesh] Add spmdization pass (PR #80518)
Boian Petkantchin via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 5 09:14:38 PST 2024
================
@@ -616,10 +577,240 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
}
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+ ShardOp target,
+ TypedValue<ShapedType> sourceShardValue,
+ SymbolTableCollection &symbolTableCollection) {
+ MeshOp srcMesh = getMesh(source, symbolTableCollection);
+ assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
+ return reshard(builder, srcMesh, source, target, sourceShardValue);
+}
+
void reshardingRegisterDependentDialects(DialectRegistry &registry) {
registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
cf::ControlFlowDialect>();
}
-} // namespace mesh
-} // namespace mlir
+#define GEN_PASS_DEF_SPMDIZATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+using UnshardedToShardedValueMap = DenseMap<Value, Value>;
+
+// Get the types of block arguments for an spmdized block.
+// Reads the sharding annotations of the arguments to deduce the sharded types.
+// Types that are not ranked tensors are left unchanged.
+SmallVector<Type>
+shardedBlockArgumentTypes(Block &block,
+ SymbolTableCollection &symbolTableCollection) {
+ SmallVector<Type> res;
+ llvm::transform(block.getArguments(), std::back_inserter(res),
+ [&symbolTableCollection](BlockArgument arg) {
+ auto rankedTensorArg =
+ arg.dyn_cast<TypedValue<RankedTensorType>>();
+ if (!rankedTensorArg) {
+ return arg.getType();
+ }
+
+ assert(rankedTensorArg.hasOneUse());
+ Operation *useOp = *rankedTensorArg.getUsers().begin();
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
+ assert(shardOp);
+ MeshOp mesh = getMesh(shardOp, symbolTableCollection);
+ return shardShapedType(rankedTensorArg.getType(), mesh,
+ shardOp.getShardAttr())
+ .cast<Type>();
+ });
+ return res;
+}
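
To make the deduction concrete: assuming a 1-d mesh of 2 devices, and with
mesh.shard syntax approximated from this PR's tests, a block argument
%arg0 : tensor<4x8xf32> whose single use is

    %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<4x8xf32>

gets the spmdized argument type tensor<2x8xf32>: mesh axis 0 splits tensor
dimension 0 across the 2 devices, and the remaining dimension stays whole.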
+
+static LogicalResult spmdizeOperation(
+ Operation &op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+ ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
+ if (!shardingInterface) {
+ // If there is no sharding interface, be conservative and assume that the
+ // op should be fully replicated on all devices.
+ spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTableCollection, builder);
+ } else {
+ if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTableCollection, builder))) {
+ return failure();
+ }
+ }
+
+ assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
+ return spmdizationMap.contains(result);
+ }));
+
+ return success();
+}
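
For reference, a minimal sketch of what the fully replicated fallback can do
(this is an assumption about spmdizeFullyReplicatedOperation, not its actual
body): since every device holds a full copy of each operand, the op can be
cloned unchanged, with the mapping handling operand substitution and
recording the cloned results.

    // Hypothetical sketch, not the real helper: clone the op 1:1 per device.
    static void spmdizeFullyReplicatedSketch(Operation &op,
                                             IRMapping &spmdizationMap,
                                             OpBuilder &builder) {
      // clone() remaps operands through spmdizationMap and also maps the
      // original results to the cloned ones.
      builder.clone(op, spmdizationMap);
    }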
+
+// Retrieve the sharding annotations for the operands of the given operation.
+// If the type is not a ranked tensor it is not required to have an annotation.
+static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
+ SmallVector<MeshShardingAttr> res;
+ res.reserve(op.getNumOperands());
+ llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
+ TypedValue<RankedTensorType> rankedTensor =
+ operand.dyn_cast<TypedValue<RankedTensorType>>();
+ if (!rankedTensor) {
+ return MeshShardingAttr();
+ }
+
+ Operation *definingOp = operand.getDefiningOp();
+ assert(definingOp);
+ ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
+ assert(shardOp.getAnnotateForUsers());
+ return shardOp.getShard();
+ });
+ return res;
+}
+
+// Retrieve the sharding annotations for the results of the given operation.
+// If the type is not a ranked tensor it is not required to have an annotation.
+static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
+ SmallVector<MeshShardingAttr> res;
+ res.reserve(op.getNumResults());
+ llvm::transform(op.getResults(), std::back_inserter(res),
+ [](OpResult result) {
+ TypedValue<RankedTensorType> rankedTensor =
+ result.dyn_cast<TypedValue<RankedTensorType>>();
+ if (!rankedTensor) {
+ return MeshShardingAttr();
+ }
+
+ assert(result.hasOneUse());
+ Operation *userOp = *result.getUsers().begin();
+ ShardOp shardOp = llvm::cast<ShardOp>(userOp);
+ assert(!shardOp.getAnnotateForUsers());
+ return shardOp.getShard();
+ });
+ return res;
+}
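
Both helpers above rely on the annotation convention that the asserts check:
a tensor value flowing between two ops carries a back-to-back pair of
mesh.shard ops, where the first (without annotate_for_users) records the
producer's result sharding and the second (with annotate_for_users) records
the consumer's operand sharding. Roughly, with approximated syntax:

    %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<4x8xf32>
    %2 = mesh.shard %1 to <@mesh_1d, [[0]]> annotate_for_users : tensor<4x8xf32>

getResultShardings for the op defining %0 reads the first annotation;
getOperandShardings for the op consuming %2 reads the second.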
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+ if (shardOp) {
+ if (!shardOp.getAnnotateForUsers()) {
+ return success();
+ }
+
+ // Insert resharding.
+ ShardOp srcShardOp =
+ llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
+ assert(!srcShardOp.getAnnotateForUsers());
+ TypedValue<ShapedType> srcSpmdValue =
+ spmdizationMap.lookup(srcShardOp.getOperand())
+ .cast<TypedValue<ShapedType>>();
+ Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+ symbolTableCollection);
+ assert(!spmdizationMap.contains(shardOp.getResult()));
+ spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ return success();
+ }
+
+ SmallVector<Value> spmdizedOperands;
+ llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
+ [&spmdizationMap](Value operand) {
+ assert(spmdizationMap.contains(operand));
+ return spmdizationMap.lookup(operand);
+ });
+ return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
+ getResultShardings(op), spmdizationMap,
+ symbolTableCollection, builder);
+}
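
The ShardOp branch above is where disagreeing annotations turn into
communication: when the producer and user shardings differ, the pair is
replaced by whatever reshard() builds between them. As a hedged illustration
with the same approximated syntax, an annotation pair such as

    %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<4x8xf32>
    %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users : tensor<4x8xf32>

asks for a resharding from a dimension-0 split (a tensor<2x8xf32> shard per
device) back to full replication (tensor<4x8xf32> on every device), i.e. an
all-gather-like exchange.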
+
+static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ SmallVector<Location> argLocations;
+ llvm::transform(block.getArguments(), std::back_inserter(argLocations),
+ [](BlockArgument arg) { return arg.getLoc(); });
+ Block *newBlock = builder.createBlock(
+ block.getParent(), {},
+ shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
+ for (auto [unshardedBlockArg, spmdizedBlockArg] :
+ llvm::zip(block.getArguments(), newBlock->getArguments())) {
+ spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+ }
+
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(newBlock);
+ for (Operation &op : block.getOperations()) {
+ if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
+ builder))) {
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+static LogicalResult
+spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection) {
+ OpBuilder builder(op.getFunctionBody());
+
+ // Snapshot the original blocks so that adding new blocks does not disturb
+ // the iteration.
+ SmallVector<Block *> originalBlocks;
+ llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
+ [](Block &b) { return &b; });
+
+ for (Block *block : originalBlocks) {
+ if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
+ builder))) {
+ return failure();
+ }
+ }
+
+ for (Block *block : originalBlocks) {
+ block->erase();
+ }
+
+ // Find a return op and rewrite the function result types to match the
+ // types of its operands.
+ func::ReturnOp returnOp;
+ for (Block &block : op.getBody()) {
+ if (block.empty()) {
+ continue;
+ }
+
+ returnOp = llvm::dyn_cast<func::ReturnOp>(block.back());
+ if (returnOp) {
+ break;
+ }
+ }
+ assert(returnOp);
+ op.setFunctionType(FunctionType::get(op->getContext(),
+ op.getBody().front().getArgumentTypes(),
+ returnOp->getOperandTypes()));
+
+ return success();
+}
+
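The pass body then only has to wire these helpers together. A minimal sketch
of a plausible runOnOperation, assuming the pass is scheduled on
func::FuncOp:

    // Hedged sketch of the pass driver; assumes a function pass.
    void runOnOperation() /* override */ {
      IRMapping spmdizationMap;
      SymbolTableCollection symbolTableCollection;
      if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                               symbolTableCollection)))
        return signalPassFailure();
    }
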
+struct Spmdization : public impl::SpmdizationBase<Spmdization> {
----------------
sogartar wrote:
Thank you, fixed it.
https://github.com/llvm/llvm-project/pull/80518