[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)
Boian Petkantchin
llvmlistbot at llvm.org
Mon May 20 15:14:55 PDT 2024
================
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
return type;
}
+/// Annotates the value read by `operand` with a target sharding: a
+/// `ShardOp` with `annotate_for_users = false` and the given `sharding` is
+/// created right after the value, and only `operand` is rewired to read it.
+///
+/// Nothing is inserted when the owner of `operand` is itself a `ShardOp` with
+/// the same sharding and `annotate_for_users = false`. When the owner is a
+/// `ShardOp` with `annotate_for_users = false` but a different sharding, an
+/// additional `ShardOp` with `sharding` and `annotate_for_users = true` is
+/// chained between the newly created op and the owner.
+///
+/// \param sharding The sharding to attach to the operand.
+/// \param operand  The single use to annotate; other uses of the same value
+///                 (including other operands of the same owner) are untouched.
+/// \param builder  Builder used to create the ops; its insertion point is
+///                 restored on return.
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
+  if (shardOp && shardOp.getShard() == sharding &&
+      !shardOp.getAnnotateForUsers()) {
+    // Nothing to do; the correct sharding is already set.
+    return;
+  }
+
+  builder.setInsertionPointAfterValue(operandValue);
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users=*/false);
+  IRRewriter rewriter(builder);
+  // Rewire exactly this operand. Replacing every use inside `operandOp` would
+  // also move sibling operands (e.g. `op(%v, %v)`) into `newShardOp`'s use
+  // list, which derails callers that walk `operandValue`'s uses with an
+  // early-increment iterator.
+  rewriter.modifyOpInPlace(operandOp, [&] { operand.set(newShardOp); });
+
+  if (!shardOp || shardOp.getAnnotateForUsers()) {
+    return;
+  }
+
+  // The owner is a `ShardOp` with annotate_for_users = false and a different
+  // sharding: insert an annotate_for_users = true op between `newShardOp` and
+  // the owner.
+  auto newShardOp2 = builder.create<ShardOp>(
+      operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users=*/true);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+}
+
+/// Invokes the per-operand overload of maybeInsertTargetShardingAnnotation
+/// on every use of `result` that exists when the walk starts, passing the
+/// same `sharding` and `builder` each time.
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpResult result,
+                                                     OpBuilder &builder) {
+  // The callee rewires the use it is handed, so the iterator is advanced
+  // before each call (early increment) to keep the walk over `result`'s
+  // use list valid.
+  auto resultUses = llvm::make_early_inc_range(result.getUses());
+  for (OpOperand &resultUse : resultUses)
+    maybeInsertTargetShardingAnnotation(sharding, resultUse, builder);
+}
+
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Value operandValue = operand.get();
+ Operation *operandOp = operand.getOwner();
+ Operation *operandSrcOp = operandValue.getDefiningOp();
+ bool isBlockArg = !operandSrcOp;
+ ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
+
+ if (shardOp && shardOp.getShard() == sharding &&
+ shardOp.getAnnotateForUsers()) {
+ // No need for anything the correct sharding is already set.
----------------
sogartar wrote:
done
https://github.com/llvm/llvm-project/pull/84514
More information about the Mlir-commits
mailing list