[Mlir-commits] [mlir] [MLIR][shard] checking for correct&full sharding annotations (PR #176000)

Frank Schlimbach llvmlistbot at llvm.org
Wed Jan 14 09:51:23 PST 2026


https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/176000

Before trying to partition an operation, check that it is fully annotated with `shard.shard` ops. This produces clear, early error messages instead of obscure failures later in the pass.

>From 13bf2ef2306e9516f7ef02e405744c919d44319e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Jan 2026 09:46:36 -0800
Subject: [PATCH] checking for correct&full sharding annotations

---
 .../Dialect/Shard/Transforms/Partition.cpp    | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 335ca1a60f8f3..bf0617c107f5e 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -663,6 +663,32 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
   return success();
 }
 
+// Check that `op` is correctly and fully annotated with sharding information:
+//   - Ranked-tensor operands must be produced by a shard.shard operation.
+//   - Ranked-tensor results must have exactly one use, a shard.shard op.
+// Constant-like ops and non-ranked-tensor values (which shard.shard cannot
+// annotate) are exempt. 'annotate_for_users' is not inspected. Violations are
+// only reported via emitError(); no status is returned to the caller.
+static void checkFullyAnnotated(Operation *op) {
+  if (op->hasTrait<OpTrait::ConstantLike>())
+    return;
+  for (Value operand : op->getOperands()) {
+    if (isa<RankedTensorType>(operand.getType()) &&
+        !operand.getDefiningOp<ShardOp>())
+      op->emitError("Cannot partition: all operands must be produced by a "
+                    "shard.shard operation");
+  }
+  for (OpResult result : op->getResults()) {
+    if (!isa<RankedTensorType>(result.getType()))
+      continue;
+    if (!result.hasOneUse()) // with no uses, *user_begin() is invalid
+      op->emitError("Cannot partition: all results must have exactly one use");
+    else if (!isa<ShardOp>(*result.user_begin()))
+      op->emitError(
+          "Cannot partition: all result users must be shard.shard operations");
+  }
+}
+
 static LogicalResult
 partitionOperation(Operation &op, IRMapping &partitionMap,
                    SymbolTableCollection &symbolTableCollection,
@@ -670,6 +696,7 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
   if (isa<ShardingOp>(op)) {
     return success();
   }
+
   if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
     auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
     if (!shardOp) {
@@ -686,6 +713,9 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
                               builder);
   }
 
+  // check if operation is correctly and fully annotated
+  checkFullyAnnotated(&op);
+
   SmallVector<Value> partitionedOperands;
   llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
                   [&partitionMap](Value operand) {



More information about the Mlir-commits mailing list