[Mlir-commits] [mlir] f1bf37e - [mlir][shard] Simple fixes to harden sharding propagation and partitioning (#183028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 24 04:17:11 PST 2026


Author: Frank Schlimbach
Date: 2026-02-24T13:17:06+01:00
New Revision: f1bf37e6bea86587b7a55919994174cc8d8ccfab

URL: https://github.com/llvm/llvm-project/commit/f1bf37e6bea86587b7a55919994174cc8d8ccfab
DIFF: https://github.com/llvm/llvm-project/commit/f1bf37e6bea86587b7a55919994174cc8d8ccfab.diff

LOG: [mlir][shard] Simple fixes to harden sharding propagation and partitioning (#183028)

Added: 
    

Modified: 
    mlir/lib/Dialect/Shard/Transforms/Partition.cpp
    mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 8b73bdd7ea60b..9c5880e0c3b64 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -532,7 +532,8 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+            rankedTensorArg.use_empty()) {
           return arg.getType();
         }
 
@@ -660,20 +661,22 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
 }
 
 // Check if the block args are correctly annotated with sharding information:
-//   - non-tensor and 0d-tensor args are ignored
+//   - non-tensor, 0d-tensor and unused args are ignored
 //   - each tensor arg must have exactly one use, which must be a shard.shard
-//   operation
+//     operation
 static LogicalResult checkFullyAnnotated(Block &block) {
   for (const BlockArgument &arg : block.getArguments()) {
     auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+        rankedTensorArg.use_empty())
       continue;
 
-    if (rankedTensorArg.getNumUses() > 1)
+    if (!rankedTensorArg.hasOneUse())
       return emitError(block.getParent()->getLoc())
              << "Cannot partition: expected a single use for block argument "
              << arg.getArgNumber() << " in block "
              << block.computeBlockNumber();
+
     Operation *useOp = *rankedTensorArg.getUsers().begin();
     auto shardOp = dyn_cast<ShardOp>(useOp);
     if (!shardOp)

diff  --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index f954131ed7910..cff02d4f03143 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -379,6 +379,10 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
+    // Nothing to propagate if there is no sharding annotation in the block.
+    if (block.getOps<shard::ShardOp>().empty())
+      return;
+
     auto traverse = [&](auto &&range, OpBuilder &builder,
                         const char *order) -> bool {
       for (Operation &op : range) {


        


More information about the Mlir-commits mailing list