[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)
Renato Golin
llvmlistbot at llvm.org
Tue Feb 11 09:15:38 PST 2025
================
@@ -286,34 +287,40 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
- return;
+ return newShardOp ? newShardOp : shardOp;
}
- auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
- auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
- /*annotate_for_users*/ false);
+ if (!newShardOp) {
+ auto shardingOp =
+ builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+ newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+ /*annotate_for_users*/ false);
+ }
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
return use.getOwner() == operandOp && use.get() == operandValue;
});
if (!shardOp || shardOp.getAnnotateForUsers()) {
- return;
+ return newShardOp;
}
- auto newShardOp2 =
- builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
- /*annotate_for_users*/ true);
+ auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+ newShardOp.getSharding(),
+ /*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+ return newShardOp;
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
+ ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
- maybeInsertTargetShardingAnnotation(sharding, use, builder);
+ newShardOp =
----------------
rengolin wrote:
What's the use of this returned variable?
https://github.com/llvm/llvm-project/pull/124724
More information about the Mlir-commits mailing list