[Mlir-commits] [mlir] [mlir][mesh] add support in spmdization for incomplete sharding annotations (PR #82442)
Mehdi Amini
llvmlistbot at llvm.org
Wed Feb 21 18:08:55 PST 2024
================
@@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- assert(!shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
}
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
- if (shardOp) {
- if (!shardOp.getAnnotateForUsers()) {
- return success();
- }
-
+ Value targetSpmdValue;
+
+ // Check if 2 shard ops are chained. If not, there is no need for resharding,
+ // as the source and target share the same sharding.
+ ShardOp srcShardOp =
+ llvm::dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
----------------
joker-eph wrote:
Nit: we should not need the `llvm::` prefix I believe.
https://github.com/llvm/llvm-project/pull/82442