[Mlir-commits] [mlir] [mlir][mesh] add support in spmdization for incomplete sharding annotations (PR #82442)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 20 16:18:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

Don't require that `mesh.shard` operations come in pairs. If there is only a single `mesh.shard` operation between a producer and its consumer, we assume that the producer result and the consumer operand have the same sharding.

---
Full diff: https://github.com/llvm/llvm-project/pull/82442.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+38-15) 
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 7cbe0de048769b..287db5dd08c5fd 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
     Operation *definingOp = operand.getDefiningOp();
     assert(definingOp);
     ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
-    assert(shardOp.getAnnotateForUsers());
     return shardOp.getShard();
   });
   return res;
@@ -615,34 +614,58 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
                     assert(result.hasOneUse());
                     Operation *userOp = *result.getUsers().begin();
                     ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    assert(!shardOp.getAnnotateForUsers());
                     return shardOp.getShard();
                   });
   return res;
 }
 
+/// Returns the `mesh.shard` op whose result feeds `targetShardOp`, i.e. the
+/// other half of a chained pair of shard annotations. Returns a null ShardOp
+/// when the operand is a block argument or is produced by an op that is not a
+/// `mesh.shard`.
+static ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
+  Operation *srcOp = targetShardOp.getOperand().getDefiningOp();
+  if (!srcOp) {
+    // Block argument: there is no source annotation.
+    return ShardOp();
+  }
+  // dyn_cast yields a null ShardOp when srcOp is not a mesh.shard op.
+  return llvm::dyn_cast<ShardOp>(srcOp);
+}
+
 static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                  SymbolTableCollection &symbolTableCollection,
                  OpBuilder &builder) {
-  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
-  if (shardOp) {
-    if (!shardOp.getAnnotateForUsers()) {
-      return success();
-    }
+  Value targetSpmdValue;

+  // Check if 2 shard ops are chained. If not, there is no need for resharding
+  // as the source and target share the same sharding.
+  ShardOp srcShardOp = getSourceShardOpOrNull(shardOp);
+  if (!srcShardOp) {
+    targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+  } else {
     // Insert resharding.
-    ShardOp srcShardOp =
-        llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
-    assert(!srcShardOp.getAnnotateForUsers());
+    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
     TypedValue<ShapedType> srcSpmdValue =
         spmdizationMap.lookup(srcShardOp.getOperand())
             .cast<TypedValue<ShapedType>>();
-    Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                                      symbolTableCollection);
-    assert(!spmdizationMap.contains(shardOp.getResult()));
-    spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
-    return success();
+  }
+
+  assert(!spmdizationMap.contains(shardOp.getResult()));
+  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+  return success();
+}
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+                 SymbolTableCollection &symbolTableCollection,
+                 OpBuilder &builder) {
+  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+  if (shardOp) {
+    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, builder);
   }
 
   SmallVector<Value> spmdizedOperands;
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2fb8029dfe64ae..258c3786e3518c 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
   // CHECK: return %[[RESHARD3]] : tensor<1xi8>
   return %7 : tensor<2xi8>
 }
+
+// CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+  %arg0: tensor<8x16xf32>
+  // CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+  // CHECK: return %[[RES]] : tensor<4x16xf32>
+  return %2 : tensor<8x16xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/82442


More information about the Mlir-commits mailing list