[mlir] [llvm] [mlir][mesh] Add spmdization pass (PR #80518)

Boian Petkantchin via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 09:14:38 PST 2024


================
@@ -616,10 +577,240 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
       source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
 }
 
+// Convenience overload of `reshard` that looks up the mesh referenced by the
+// shard ops in `symbolTableCollection`. `source` and `target` must both refer
+// to the same, existing mesh.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+                               ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue,
+                               SymbolTableCollection &symbolTableCollection) {
+  MeshOp mesh = getMesh(source, symbolTableCollection);
+  assert(mesh);
+  assert(mesh == getMesh(target, symbolTableCollection));
+  return reshard(builder, mesh, source, target, sourceShardValue);
+}
+
 // Register in `registry` every dialect that the resharding code depends on.
 void reshardingRegisterDependentDialects(DialectRegistry &registry) {
   registry.insert<arith::ArithDialect>();
   registry.insert<mesh::MeshDialect>();
   registry.insert<tensor::TensorDialect>();
   registry.insert<cf::ControlFlowDialect>();
 }
 
-} // namespace mesh
-} // namespace mlir
+#define GEN_PASS_DEF_SPMDIZATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+// Map from a value of the original (unsharded) IR to its spmdized counterpart,
+// going by the alias name.
+// NOTE(review): not referenced in the visible part of this patch, where the
+// spmdization functions take an `IRMapping` instead — confirm it is still
+// needed.
+using UnshardedToShardedValueMap = DenseMap<Value, Value>;
+
+// Compute the argument types of the spmdized counterpart of `block`.
+// A ranked-tensor argument must have exactly one user, which must be a
+// `ShardOp`. Its sharding annotation and mesh determine the sharded type.
+// Arguments of any other type keep their original type.
+SmallVector<Type>
+shardedBlockArgumentTypes(Block &block,
+                          SymbolTableCollection &symbolTableCollection) {
+  SmallVector<Type> shardedTypes;
+  for (BlockArgument arg : block.getArguments()) {
+    auto tensorArg = arg.dyn_cast<TypedValue<RankedTensorType>>();
+    if (!tensorArg) {
+      shardedTypes.push_back(arg.getType());
+      continue;
+    }
+
+    assert(tensorArg.hasOneUse());
+    ShardOp annotation =
+        llvm::dyn_cast<ShardOp>(*tensorArg.getUsers().begin());
+    assert(annotation);
+    MeshOp mesh = getMesh(annotation, symbolTableCollection);
+    Type shardedType = shardShapedType(tensorArg.getType(), mesh,
+                                       annotation.getShardAttr())
+                           .cast<Type>();
+    shardedTypes.push_back(shardedType);
+  }
+  return shardedTypes;
+}
+
+// Spmdize `op`, given the already spmdized values of its operands and the
+// sharding annotations of its operands and results.
+// Ops implementing `ShardingInterface` are spmdized through the interface.
+// All other ops fall back to `spmdizeFullyReplicatedOperation`.
+// After success, every result of `op` is expected to have an entry in
+// `spmdizationMap` (checked by an assert in debug builds).
+static LogicalResult spmdizeOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
+  if (!shardingInterface) {
+    // If there is no sharding interface we are conservative and assume that
+    // the op should be fully replicated on all devices.
+    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
+                                    resultShardings, spmdizationMap,
+                                    symbolTableCollection, builder);
+  } else if (failed(shardingInterface.spmdize(
+                 spmdizedOperands, operandShardings, resultShardings,
+                 spmdizationMap, symbolTableCollection, builder))) {
+    return failure();
+  }
+
+  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
+    return spmdizationMap.contains(result);
+  }));
+
+  return success();
+}
+
+// Retrieve the sharding annotations for the operands of the given operation.
+// An operand that is not a ranked tensor is not required to have an
+// annotation and gets a null `MeshShardingAttr`. A ranked-tensor operand must
+// be defined by a `ShardOp` that annotates the value for its users.
+// The result has one entry per operand, in operand order.
+static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
+  SmallVector<MeshShardingAttr> res;
+  res.reserve(op.getNumOperands());
+  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
+    TypedValue<RankedTensorType> rankedTensor =
+        operand.dyn_cast<TypedValue<RankedTensorType>>();
+    if (!rankedTensor) {
+      return MeshShardingAttr();
+    }
+
+    Operation *definingOp = operand.getDefiningOp();
+    assert(definingOp && "ranked tensor operand must be defined by a ShardOp");
+    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
+    assert(shardOp.getAnnotateForUsers() &&
+           "operand ShardOp must annotate the value for its users");
+    return shardOp.getShard();
+  });
+  return res;
+}
+
+// Retrieve the sharding annotations for the results of the given operation.
+// A result that is not a ranked tensor is not required to have an annotation
+// and gets a null `MeshShardingAttr`. A ranked-tensor result must have exactly
+// one user, which must be a `ShardOp` that does not annotate for users (i.e.
+// it annotates the result itself).
+// The result has one entry per result of `op`, in result order.
+static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
+  SmallVector<MeshShardingAttr> res;
+  res.reserve(op.getNumResults());
+  llvm::transform(op.getResults(), std::back_inserter(res),
+                  [](OpResult result) {
+                    TypedValue<RankedTensorType> rankedTensor =
+                        result.dyn_cast<TypedValue<RankedTensorType>>();
+                    if (!rankedTensor) {
+                      return MeshShardingAttr();
+                    }
+
+                    assert(result.hasOneUse() &&
+                           "ranked tensor result must have a single ShardOp "
+                           "user");
+                    Operation *userOp = *result.getUsers().begin();
+                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
+                    assert(!shardOp.getAnnotateForUsers() &&
+                           "result ShardOp must annotate the result itself");
+                    return shardOp.getShard();
+                  });
+  return res;
+}
+
+// Spmdize one operation of the original IR and record the spmdized values of
+// its results in `spmdizationMap`.
+// - A `ShardOp` that does not annotate for users emits nothing.
+// - A `ShardOp` that annotates for users is lowered to a resharding from the
+//   result-side `ShardOp` that defines its operand. The resharding reads the
+//   spmdized value of that `ShardOp`'s operand directly.
+// - Any other op is spmdized from the spmdized values of its operands and the
+//   shardings annotated on its operands and results.
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+                 SymbolTableCollection &symbolTableCollection,
+                 OpBuilder &builder) {
+  if (ShardOp targetShardOp = llvm::dyn_cast<ShardOp>(op)) {
+    if (!targetShardOp.getAnnotateForUsers())
+      return success();
+
+    ShardOp srcShardOp =
+        llvm::cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
+    assert(!srcShardOp.getAnnotateForUsers());
+    Value srcValue = spmdizationMap.lookup(srcShardOp.getOperand());
+    TypedValue<ShapedType> srcSpmdValue = srcValue.cast<TypedValue<ShapedType>>();
+    Value reshardedValue = reshard(builder, srcShardOp, targetShardOp,
+                                   srcSpmdValue, symbolTableCollection);
+    assert(!spmdizationMap.contains(targetShardOp.getResult()));
+    spmdizationMap.map(targetShardOp.getResult(), reshardedValue);
+    return success();
+  }
+
+  SmallVector<Value> spmdizedOperands;
+  for (Value operand : op.getOperands()) {
+    assert(spmdizationMap.contains(operand));
+    spmdizedOperands.push_back(spmdizationMap.lookup(operand));
+  }
+  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
+                          getResultShardings(op), spmdizationMap,
+                          symbolTableCollection, builder);
+}
+
+// Create the spmdized counterpart of `block` in the same parent region and
+// spmdize every operation of `block` into it, in order.
+// Each argument of `block` is mapped in `spmdizationMap` to the corresponding
+// argument of the new block. The new block's argument types come from
+// `shardedBlockArgumentTypes`. The caller's insertion point in `builder` is
+// restored on return.
+static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
+                                  SymbolTableCollection &symbolTableCollection,
+                                  OpBuilder &builder) {
+  // `createBlock` moves the insertion point into the new block, so the guard
+  // must be installed before it. Otherwise it would only restore the new
+  // block's position, not the caller's.
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  SmallVector<Location> argLocations;
+  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
+                  [](BlockArgument arg) { return arg.getLoc(); });
+  Block *newBlock = builder.createBlock(
+      block.getParent(), {},
+      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
+  for (auto [unshardedBlockArg, spmdizedBlockArg] :
+       llvm::zip(block.getArguments(), newBlock->getArguments())) {
+    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+  }
+
+  builder.setInsertionPointToEnd(newBlock);
+  for (Operation &op : block.getOperations()) {
+    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
+                                builder))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+// Spmdize the body of `op` in place.
+// Every original block gets a spmdized counterpart created in the function
+// body, and the original blocks are then erased. Finally the function type is
+// set to the argument types of the (new) entry block and the operand types of
+// the first `func.return` found.
+// Returns failure if spmdizing any block fails, or if no block ends in a
+// `func.return`.
+static LogicalResult
+spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+              SymbolTableCollection &symbolTableCollection) {
+  OpBuilder builder(op.getFunctionBody());
+
+  // Snapshot the original blocks to not mess up the iteration when adding new
+  // blocks.
+  SmallVector<Block *> originalBlocks;
+  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
+                  [](Block &b) { return &b; });
+
+  for (Block *block : originalBlocks) {
+    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
+                            builder))) {
+      return failure();
+    }
+  }
+
+  // Drop every reference held by the original ops (operands, successors)
+  // before erasing anything. Otherwise a value or block used from a later
+  // original block would still have uses when its own block is destroyed.
+  // NOTE(review): blocks are not remapped in `spmdizationMap` here. Confirm
+  // that spmdized branch ops never refer to the original blocks.
+  for (Block *block : originalBlocks) {
+    block->dropAllReferences();
+  }
+  for (Block *block : originalBlocks) {
+    block->erase();
+  }
+
+  // Find a return op and change the function results signature to its operands
+  // signature. Not every block necessarily ends in a `func.return` (e.g. one
+  // ending in a branch), so use `dyn_cast` and keep searching instead of
+  // asserting on the first terminator.
+  func::ReturnOp returnOp;
+  for (Block &block : op.getBody()) {
+    if (block.empty()) {
+      continue;
+    }
+
+    returnOp = llvm::dyn_cast<func::ReturnOp>(block.back());
+    if (returnOp) {
+      break;
+    }
+  }
+  if (!returnOp) {
+    // Report instead of dereferencing a null op in release builds.
+    return op.emitOpError(
+        "expected a 'func.return' terminator after spmdization");
+  }
+  op.setFunctionType(FunctionType::get(op->getContext(),
+                                       op.getBody().front().getArgumentTypes(),
+                                       returnOp->getOperandTypes()));
+
+  return success();
+}
+
+struct Spmdization : public impl::SpmdizationBase<Spmdization> {
----------------
sogartar wrote:

Thank you, fixed it.

https://github.com/llvm/llvm-project/pull/80518


More information about the llvm-commits mailing list