[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Boian Petkantchin
llvmlistbot at llvm.org
Wed Oct 25 11:26:06 PDT 2023
================
@@ -49,16 +91,63 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
return failure();
}
- FailureOr<ShardingOption> shardingOption = shardingOp.getShardingOption();
- if (failed(shardingOption)) {
- op->emitOpError() << "fail to get sharding option from results.";
+ // collect MeshShardingAttr from results
+ SmallVector<MeshShardingAttr> resultShardings;
+ resultShardings.reserve(op->getNumResults());
+ for (OpResult result : op->getResults()) {
+ FailureOr<MeshShardingAttr> shardAttr =
+ getMeshShardingAttr(result, /*useOperandSharding*/ true);
+ if (succeeded(shardAttr))
+ resultShardings.push_back(*shardAttr);
+ else
+ resultShardings.push_back(nullptr);
+ }
+
+ // collect MeshShardingAttr from operands
+ SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
+ allowConflictsOperandShardings.resize(op->getNumOperands());
+ SmallVector<MeshShardingAttr> operandMustShardings;
+ operandMustShardings.resize(op->getNumOperands());
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+ getMeshShardingAttr(opOperand);
+ if (failed(maybeShardAttr))
+ continue;
+
+ bool annotateForUsers = maybeShardAttr->first;
+ if (annotateForUsers)
+ operandMustShardings[opOperand.getOperandNumber()] =
+ maybeShardAttr->second;
+ else
+ allowConflictsOperandShardings[opOperand.getOperandNumber()] =
+ maybeShardAttr->second;
+ }
+
+ // try to get the sharding option
+ SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
+ getOrderedPossibleShardingAttrs(operandMustShardings,
+ allowConflictsOperandShardings);
+ FailureOr<ShardingOption> finalShardingOption = failure();
+ for (ArrayRef<MeshShardingAttr> operandShardings :
+ possibleOperandShardingAttrs) {
+ FailureOr<ShardingOption> shardingOption =
+ shardingOp.getShardingOption(operandShardings, resultShardings);
----------------
sogartar wrote:
Are we trying here to pick, first, the sharding option that corresponds to the most specific operand shardings?
Why do we treat operands and results differently?
https://github.com/llvm/llvm-project/pull/69665
More information about the Mlir-commits
mailing list