[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)

Boian Petkantchin llvmlistbot at llvm.org
Mon May 20 15:14:55 PDT 2024


================
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
   return type;
 }
 
+// Make the owner of `operand` observe its value through a target sharding
+// annotation, i.e. a ShardOp carrying `sharding` with
+// `annotate_for_users = false`.
+//
+// The new ShardOp is created right after the definition of the operand's
+// value. Only the operands of the owning op that read that value are rewired
+// to it; other users of the value are untouched.
+// - If the owner already is a ShardOp with the same sharding and
+//   `annotate_for_users = false`, nothing is inserted.
+// - If the owner is a ShardOp with `annotate_for_users = false` but a
+//   different sharding, a second ShardOp (same `sharding`,
+//   `annotate_for_users = true`) is placed between the new op and the owner.
+//   Presumably this makes the existing annotation the target of a
+//   resharding; confirm against the pass consuming these annotations.
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  // Restores the caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  builder.setInsertionPointAfterValue(operandValue);
+  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
+  if (shardOp && shardOp.getShard() == sharding &&
+      !shardOp.getAnnotateForUsers()) {
+    // No need for anything; the correct sharding is already set.
+    return;
+  }
+
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ false);
+  IRRewriter rewriter(builder);
+  // Rewire only the operands of `operandOp` that read `operandValue`. If
+  // `operandOp` reads the value through several operands, all of them are
+  // redirected, not just `operand`.
+  rewriter.replaceUsesWithIf(
+      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+        return use.getOwner() == operandOp && use.get() == operandValue;
+      });
+
+  // Done unless the owner is itself a target annotation (a ShardOp with
+  // `annotate_for_users = false`) carrying a different sharding.
+  if (!shardOp || shardOp.getAnnotateForUsers()) {
+    return;
+  }
+
+  // Build the chain
+  //   value -> newShardOp (target) -> newShardOp2 (source) -> shardOp.
+  // Every use of newShardOp except the one inside newShardOp2 (i.e. the
+  // existing shardOp's operand) is moved onto newShardOp2.
+  auto newShardOp2 = builder.create<ShardOp>(
+      operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+}
+
+// Insert a target sharding annotation carrying `sharding` for every use of
+// `result`, by delegating each use to the OpOperand overload.
+//
+// `make_early_inc_range` is required because each call rewires the visited
+// use onto a newly created ShardOp. That removes the use from `result`'s use
+// list while it is being iterated.
+// NOTE(review): the OpOperand overload rewires *all* operands of the owning
+// op that read `result`. If one user reads `result` through several
+// operands, the pre-fetched next use may already have moved to the new
+// ShardOp's use list. Consider snapshotting the (value, owner) pairs before
+// rewriting. Verify.
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpResult result,
+                                                     OpBuilder &builder) {
+  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+  }
+}
+
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  Operation *operandSrcOp = operandValue.getDefiningOp();
+  bool isBlockArg = !operandSrcOp;
+  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
+
+  if (shardOp && shardOp.getShard() == sharding &&
+      shardOp.getAnnotateForUsers()) {
+    // No need for anything the correct sharding is already set.
----------------
sogartar wrote:

done

https://github.com/llvm/llvm-project/pull/84514


More information about the Mlir-commits mailing list