[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)

Boian Petkantchin llvmlistbot at llvm.org
Wed Oct 25 11:26:06 PDT 2023


================
@@ -49,16 +91,63 @@ LogicalResult visitOp(Operation *op, OpBuilder &builder) {
     return failure();
   }
 
-  FailureOr<ShardingOption> shardingOption = shardingOp.getShardingOption();
-  if (failed(shardingOption)) {
-    op->emitOpError() << "fail to get sharding option from results.";
+  // collect MeshShardingAttr from results
+  SmallVector<MeshShardingAttr> resultShardings;
+  resultShardings.reserve(op->getNumResults());
+  for (OpResult result : op->getResults()) {
+    FailureOr<MeshShardingAttr> shardAttr =
+        getMeshShardingAttr(result, /*useOperandSharding*/ true);
+    if (succeeded(shardAttr))
+      resultShardings.push_back(*shardAttr);
+    else
+      resultShardings.push_back(nullptr);
+  }
+
+  // collect MeshShardingAttr from operands
+  SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
+  allowConflictsOperandShardings.resize(op->getNumOperands());
+  SmallVector<MeshShardingAttr> operandMustShardings;
+  operandMustShardings.resize(op->getNumOperands());
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+        getMeshShardingAttr(opOperand);
+    if (failed(maybeShardAttr))
+      continue;
+
+    bool annotateForUsers = maybeShardAttr->first;
+    if (annotateForUsers)
+      operandMustShardings[opOperand.getOperandNumber()] =
+          maybeShardAttr->second;
+    else
+      allowConflictsOperandShardings[opOperand.getOperandNumber()] =
+          maybeShardAttr->second;
+  }
+
+  // try to get the sharding option
+  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
+      getOrderedPossibleShardingAttrs(operandMustShardings,
+                                      allowConflictsOperandShardings);
+  FailureOr<ShardingOption> finalShardingOption = failure();
+  for (ArrayRef<MeshShardingAttr> operandShardings :
+       possibleOperandShardingAttrs) {
+    FailureOr<ShardingOption> shardingOption =
+        shardingOp.getShardingOption(operandShardings, resultShardings);
----------------
sogartar wrote:

Are we trying here to first pick the sharding option that corresponds to the most specific operand shardings?
Why do we treat operands and results differently?

https://github.com/llvm/llvm-project/pull/69665


More information about the Mlir-commits mailing list