[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

Frank Schlimbach llvmlistbot at llvm.org
Wed Feb 12 01:29:20 PST 2025


================
@@ -286,34 +287,40 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
     // No need for anything the correct sharding is already set.
-    return;
+    return newShardOp ? newShardOp : shardOp;
   }
 
-  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
-  auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
-                              /*annotate_for_users*/ false);
+  if (!newShardOp) {
+    auto shardingOp =
+        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
+    newShardOp =
+        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
+                                /*annotate_for_users*/ false);
+  }
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
       operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
   if (!shardOp || shardOp.getAnnotateForUsers()) {
-    return;
+    return newShardOp;
   }
 
-  auto newShardOp2 =
-      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
-                              /*annotate_for_users*/ true);
+  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
+                                             newShardOp.getSharding(),
+                                             /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  return newShardOp;
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
+  ShardOp newShardOp;
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+    newShardOp =
----------------
fschlimb wrote:

`maybeInsertTargetShardingAnnotation` creates the "target" sharding for the given result if it does not yet exist. When the result has multiple uses, we do not want to create it more than once. Hence we have to pass it back to the next call to `maybeInsertTargetShardingAnnotation` if there is more than one use.

https://github.com/llvm/llvm-project/pull/124724


More information about the Mlir-commits mailing list