[Mlir-commits] [mlir] [MLIR][shard] checking for correct&full sharding annotations (PR #176000)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Jan 15 07:26:34 PST 2026
================
@@ -663,13 +663,87 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
return success();
}
+// Check if the block args are correctly annotated with sharding information:
+// - non-tensor and 0d-tensor args are ignored
+// - each tensor arg must have exactly one use, which must be a shard.shard
+// operation
+static LogicalResult checkFullyAnnotated(Block &block) {
+ for (auto arg : block.getArguments()) {
+ auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
+ if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+ continue;
+
+ if (rankedTensorArg.getNumUses() > 1)
+ return emitError(block.getParent()->getLoc())
+ << "Cannot partition: expected a single use for block argument "
+ << arg.getArgNumber() << " in block "
+ << block.computeBlockNumber();
+ Operation *useOp = *rankedTensorArg.getUsers().begin();
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
+ if (!shardOp)
+ return emitError(block.getParent()->getLoc())
+ << "Cannot partition: expected a shard.shard op for block "
+ << "argument " << arg.getArgNumber() << " in block "
+ << block.computeBlockNumber();
+ }
+ return success();
+}
+
+// Check if the operation is correctly and fully annotated with sharding
+// information:
+// - Operation results must have exactly one use (i.e. the shard.shard op).
+// - All operands and all results must be annotated, i.e. they must be
+// produced by/consumed by a shard.shard operation.
+// - Result annotations must not include the 'annotate_for_users' attribute.
+// - Operand annotations must include the 'annotate_for_users' attribute.
+// Constant-like ops are exempt, as are non-tensor and 0-d tensor operands.
+// Emits an error if the operation is not correctly and fully annotated.
+static LogicalResult checkFullyAnnotated(Operation *op) {
+ // constant ops do not need to have sharding annotations
+ if (op->hasTrait<OpTrait::ConstantLike>())
+ return success();
+
+ for (auto &operand : op->getOpOperands()) {
+ // non-tensor and 0d-tensor operands are ignored
+ auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
+ if (!rankedTT || rankedTT.getRank() == 0)
+ continue;
+
+ auto shard = operand.get().getDefiningOp<ShardOp>();
+ if (!shard)
+ return op->emitError() << "Cannot partition: tensor operand "
+ << operand.getOperandNumber()
+ << " must be defined by a shard.shard operation.";
+ else if (!shard.getAnnotateForUsers())
----------------
kuhar wrote:
no else after return: https://llvm.org/docs/CodingStandards.html#don-t-use-else-after-a-return
https://github.com/llvm/llvm-project/pull/176000
More information about the Mlir-commits
mailing list