[Mlir-commits] [mlir] [mlir][mesh] add support in spmdization for incomplete sharding annotations (PR #82442)

Mehdi Amini llvmlistbot at llvm.org
Wed Feb 21 18:08:55 PST 2024


================
@@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
                     assert(result.hasOneUse());
                     Operation *userOp = *result.getUsers().begin();
                     ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    assert(!shardOp.getAnnotateForUsers());
                     return shardOp.getShard();
                   });
   return res;
 }
 
 static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                  SymbolTableCollection &symbolTableCollection,
                  OpBuilder &builder) {
-  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
-  if (shardOp) {
-    if (!shardOp.getAnnotateForUsers()) {
-      return success();
-    }
-
+  Value targetSpmdValue;
+
+  // Check if 2 shard ops are chained. If not there is no need for resharding
+  // as the source and target shared the same sharding.
+  ShardOp srcShardOp =
+      llvm::dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
----------------
joker-eph wrote:

Nit: I believe the `llvm::` prefix is not needed here.
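
To illustrate why the prefix can likely be dropped: as far as I know, `mlir/Support/LLVM.h` re-exports the LLVM casting helpers into the `mlir` namespace through using-declarations, so code inside `mlir` can call them unqualified. The sketch below reproduces only that mechanism with stand-in namespaces and types. Everything in it is a simplified assumption for illustration: the `dynamic_cast`-based helper, the empty `Operation`/`ShardOp`/`OtherOp` structs, and the hypothetical `getSourceShardOp` function are not the real LLVM or MLIR definitions.

```cpp
// Stand-in for LLVM's casting utilities (simplified; the real helper relies
// on LLVM-style RTTI via classof, not on dynamic_cast).
namespace llvm {
template <typename To, typename From>
To *dyn_cast_or_null(From *value) {
  // Tolerate a null input, otherwise attempt the checked downcast.
  return value ? dynamic_cast<To *>(value) : nullptr;
}
} // namespace llvm

namespace mlir {
// Re-export the helper into this namespace, mirroring what the MLIR support
// header is assumed to do for the real casting functions.
using llvm::dyn_cast_or_null;

// Hypothetical stand-ins for operation classes.
struct Operation {
  virtual ~Operation() = default;
};
struct ShardOp : Operation {};
struct OtherOp : Operation {};

// Hypothetical helper: because of the using-declaration above, the call
// needs no llvm:: qualification.
inline ShardOp *getSourceShardOp(Operation *definingOp) {
  return dyn_cast_or_null<ShardOp>(definingOp);
}
} // namespace mlir
```

Applied to the hunk above, the same reasoning would let the new line read `dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp())`, assuming the enclosing code sits inside the `mlir` namespace or has an equivalent using-directive in scope.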

https://github.com/llvm/llvm-project/pull/82442

