[Mlir-commits] [mlir] [MLIR][shard] checking for correct&full sharding annotations (PR #176000)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 14 09:51:53 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
Before trying to partition an operation, check that it is fully annotated with `shard.shard` ops. This produces useful error messages instead of obscure failures later in partitioning.
---
Full diff: https://github.com/llvm/llvm-project/pull/176000.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+30)
``````````diff
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 335ca1a60f8f3..bf0617c107f5e 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -663,6 +663,32 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
return success();
}
+// Check that the operation is correctly and fully annotated with sharding
+// information (ConstantLike ops are exempt):
+// - every operand must be produced by a shard.shard operation;
+// - every result must have exactly one use, and that user must be a
+//   shard.shard operation.
+// Emits an error on `op` for each violation; it does not stop partitioning.
+// TODO: check 'annotate_for_users' (set on operand, unset on result shards).
+static void checkFullyAnnotated(Operation *op) {
+  // Constant ops do not need to have sharding annotations.
+  if (op->hasTrait<OpTrait::ConstantLike>())
+    return;
+  for (Value operand : op->getOperands())
+    if (!operand.getDefiningOp<ShardOp>())
+      op->emitError("Cannot partition: all operands must be produced by a "
+                    "shard.shard operation");
+  for (Value result : op->getResults()) {
+    if (!result.hasOneUse()) {
+      op->emitError("Cannot partition: all results must have exactly one use");
+      continue; // With zero uses, *user_begin() would dereference end().
+    }
+    if (!isa<ShardOp>(*result.user_begin()))
+      op->emitError(
+          "Cannot partition: all result users must be shard.shard operations");
+  }
+}
+
static LogicalResult
partitionOperation(Operation &op, IRMapping &partitionMap,
SymbolTableCollection &symbolTableCollection,
@@ -670,6 +696,7 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
if (isa<ShardingOp>(op)) {
return success();
}
+
if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
if (!shardOp) {
@@ -686,6 +713,9 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
builder);
}
+ // check if operation is correctly and fully annotated
+ checkFullyAnnotated(&op);
+
SmallVector<Value> partitionedOperands;
llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
[&partitionMap](Value operand) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/176000
More information about the Mlir-commits mailing list