[Mlir-commits] [mlir] [mlir][shard] Hardening sharding propagation and partitioning (PR #183028)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Feb 24 03:26:23 PST 2026
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/183028
>From 739beec727660e76927988daac67956d7c674afb Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 24 Feb 2026 03:16:30 -0800
Subject: [PATCH] hardening sharding propagation and partitioning
---
mlir/lib/Dialect/Shard/Transforms/Partition.cpp | 7 +++++--
mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp | 4 ++++
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 8b73bdd7ea60b..1224fdef164dd 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -532,7 +532,8 @@ shardedBlockArgumentTypes(Block &block,
block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
- if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
+ if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+ rankedTensorArg.use_empty()) {
return arg.getType();
}
@@ -666,7 +667,8 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
static LogicalResult checkFullyAnnotated(Block &block) {
for (const BlockArgument &arg : block.getArguments()) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
- if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+ if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
+ rankedTensorArg.use_empty())
continue;
if (rankedTensorArg.getNumUses() > 1)
@@ -674,6 +676,7 @@ static LogicalResult checkFullyAnnotated(Block &block) {
<< "Cannot partition: expected a single use for block argument "
<< arg.getArgNumber() << " in block "
<< block.computeBlockNumber();
+
Operation *useOp = *rankedTensorArg.getUsers().begin();
auto shardOp = dyn_cast<ShardOp>(useOp);
if (!shardOp)
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index f954131ed7910..cff02d4f03143 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -379,6 +379,10 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
+ // Nothing to propagate if there is no sharding annotation in the block.
+ if (block.getOps<shard::ShardOp>().empty())
+ return;
+
auto traverse = [&](auto &&range, OpBuilder &builder,
const char *order) -> bool {
for (Operation &op : range) {
More information about the Mlir-commits
mailing list