[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

Frank Schlimbach llvmlistbot at llvm.org
Wed Feb 12 03:01:09 PST 2025


================
@@ -286,34 +287,40 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
     // No need for anything the correct sharding is already set.
-    return;
+    return newShardOp ? newShardOp : shardOp;
   }
 
-  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
-  auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
-                              /*annotate_for_users*/ false);
+  if (!newShardOp) {
+    auto shardingOp =
+        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+    newShardOp =
+        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+                                /*annotate_for_users*/ false);
+  }
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
       operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
   if (!shardOp || shardOp.getAnnotateForUsers()) {
-    return;
+    return newShardOp;
   }
 
-  auto newShardOp2 =
-      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
-                              /*annotate_for_users*/ true);
+  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+                                             newShardOp.getSharding(),
+                                             /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  return newShardOp;
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
+  ShardOp newShardOp;
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+    newShardOp =
----------------
fschlimb wrote:

I misunderstood your comment. I have now made the change: `maybeInsertTargetShardingAnnotation` accepts a reference and returns void.

https://github.com/llvm/llvm-project/pull/124724


More information about the Mlir-commits mailing list