[Mlir-commits] [mlir] [MLIR][shard] checking for correct&full sharding annotations (PR #176000)

Jakub Kuderski llvmlistbot at llvm.org
Thu Jan 15 07:26:35 PST 2026


================
@@ -663,13 +663,87 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
   return success();
 }
 
+// Check that the block arguments are correctly annotated with sharding
+// information:
+//   - non-tensor and 0d-tensor args are ignored
+//   - every other (ranked tensor) arg must have exactly one use, and that use
+//     must be a shard.shard operation
+// Emits an error at the location reported by the block's parent region and
+// returns failure for the first argument that violates these rules; returns
+// success otherwise.
+static LogicalResult checkFullyAnnotated(Block &block) {
+  for (BlockArgument arg : block.getArguments()) {
+    auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
+    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+      continue;
+
+    // Require exactly one use. Rejecting the zero-use case here is what makes
+    // dereferencing the first user below safe: with an empty use list,
+    // getUsers().begin() would be the end iterator.
+    if (!rankedTensorArg.hasOneUse())
+      return emitError(block.getParent()->getLoc())
+             << "Cannot partition: expected a single use for block argument "
+             << arg.getArgNumber() << " in block "
+             << block.computeBlockNumber();
+    Operation *useOp = *rankedTensorArg.getUsers().begin();
+    auto shardOp = dyn_cast<ShardOp>(useOp);
+    if (!shardOp)
+      return emitError(block.getParent()->getLoc())
+             << "Cannot partition: expected a shard.shard op for block "
+             << "argument " << arg.getArgNumber() << " in block "
+             << block.computeBlockNumber();
+  }
+  return success();
+}
+
+// Check if the operation is correctly and fully annotated with sharding
+// information:
+//   - Operation results must have exactly one use (e.g. the shard operation).
+//   - All operands and all results must be annotated, e.g. they must be
+//     produced by/consumed by a shard.shard operation.
+//   - Result annotations must not include the 'annotate_for_users' attribute.
+//   - Operand annotations must include the 'annotate_for_users' attribute.
+// raises an error if the operation is not correctly and fully annotated.
+static LogicalResult checkFullyAnnotated(Operation *op) {
+  // constant ops do not need to have sharding annotations
+  if (op->hasTrait<OpTrait::ConstantLike>())
+    return success();
+
+  for (auto &operand : op->getOpOperands()) {
+    // non-tensor and 0d-tensor operands are ignored
+    auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
+    if (!rankedTT || rankedTT.getRank() == 0)
+      continue;
+
+    auto shard = operand.get().getDefiningOp<ShardOp>();
+    if (!shard)
+      return op->emitError() << "Cannot partition: tensor operand "
+                             << operand.getOperandNumber()
+                             << " must be defined by a shard.shard operation.";
+    else if (!shard.getAnnotateForUsers())
+      return op->emitError()
+             << "Cannot partition: shard.shard for operand "
+             << operand.getOperandNumber() << " must set 'annotate_for_users'.";
+  }
+  for (auto result : op->getResults()) {
+    if (!result.hasOneUse())
+      return op->emitError()
+             << "Cannot partition: result " << result.getResultNumber()
+             << " must have exactly one use.";
+    auto shard = dyn_cast<ShardOp>(*result.user_begin());
+    if (!shard)
+      return op->emitError()
+             << "Cannot partition: user of result " << result.getResultNumber()
+             << " must be shard.shard operation.";
+    else if (shard.getAnnotateForUsers())
----------------
kuhar wrote:

also here

https://github.com/llvm/llvm-project/pull/176000


More information about the Mlir-commits mailing list