[Mlir-commits] [mlir] [mlir][mesh] add support in spmdization for incomplete sharding annotations (PR #82442)
Boian Petkantchin
llvmlistbot at llvm.org
Thu Feb 22 08:26:42 PST 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/82442
>From 91ff453870cdf152b99506348a7f334564af69b3 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 20 Feb 2024 16:05:52 -0800
Subject: [PATCH 1/4] [mlir][mesh] add support in spmdization for incomplete
sharding annotations
Don't require that `mesh.shard` operations come in pairs.
If there is only a single `mesh.shard` operation, we assume that the producer
result and consumer operand have the same sharding.
---
.../Dialect/Mesh/Transforms/Spmdization.cpp | 56 +++++++++++++------
mlir/test/Dialect/Mesh/spmdization.mlir | 14 +++++
2 files changed, 54 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 7cbe0de048769b..73cae225ea69d9 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- assert(shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
@@ -615,34 +614,59 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- assert(!shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
}
+ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
+ Operation *srcOp = targetShardOp.getOperand().getDefiningOp();
+ if (!srcOp) {
+ return ShardOp();
+ }
+ ShardOp srcShardOp =
+ llvm::dyn_cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
+ if (!srcShardOp) {
+ return ShardOp();
+ }
+
+ return srcShardOp;
+}
+
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
- if (shardOp) {
- if (!shardOp.getAnnotateForUsers()) {
- return success();
- }
+ Value targetSpmdValue;
+ // Check if 2 shard ops are chained. If not there is no need for resharding
+ // as the source and target shared the same sharding.
+ ShardOp srcShardOp = getSourceShardOpOrNull(shardOp);
+ if (!srcShardOp) {
+ targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+ } else {
// Insert resharding.
- ShardOp srcShardOp =
- llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
- assert(!srcShardOp.getAnnotateForUsers());
+ assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
- Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
- symbolTableCollection);
- assert(!spmdizationMap.contains(shardOp.getResult()));
- spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
- return success();
+ targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+ symbolTableCollection);
+ }
+
+ assert(!spmdizationMap.contains(shardOp.getResult()));
+ spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ return success();
+}
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+ if (shardOp) {
+ return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
+ builder);
}
SmallVector<Value> spmdizedOperands;
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2fb8029dfe64ae..258c3786e3518c 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
return %7 : tensor<2xi8>
}
+
+// // CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+ %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<4x16xf32>
+ return %2 : tensor<8x16xf32>
+}
>From 04e1038ac9e1c01fb70471e652bbca67a8774ff1 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 20 Feb 2024 16:21:30 -0800
Subject: [PATCH 2/4] Make one function static
---
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 73cae225ea69d9..2ba3b225669a22 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -619,7 +619,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
return res;
}
-ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
+static ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
Operation *srcOp = targetShardOp.getOperand().getDefiningOp();
if (!srcOp) {
return ShardOp();
>From 6502820f2d37687f58c98af14f4f49ca773e35d0 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 21 Feb 2024 06:35:24 -0800
Subject: [PATCH 3/4] Address PR comments
---
.../lib/Dialect/Mesh/Transforms/Spmdization.cpp | 17 ++---------------
mlir/test/Dialect/Mesh/spmdization.mlir | 2 +-
2 files changed, 3 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 2ba3b225669a22..2dace232f6aced 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -619,20 +619,6 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
return res;
}
-static ShardOp getSourceShardOpOrNull(ShardOp targetShardOp) {
- Operation *srcOp = targetShardOp.getOperand().getDefiningOp();
- if (!srcOp) {
- return ShardOp();
- }
- ShardOp srcShardOp =
- llvm::dyn_cast<ShardOp>(targetShardOp.getOperand().getDefiningOp());
- if (!srcShardOp) {
- return ShardOp();
- }
-
- return srcShardOp;
-}
-
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
@@ -641,7 +627,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
// Check if 2 shard ops are chained. If not there is no need for resharding
// as the source and target shared the same sharding.
- ShardOp srcShardOp = getSourceShardOpOrNull(shardOp);
+ ShardOp srcShardOp =
+ llvm::dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
if (!srcShardOp) {
targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
} else {
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 258c3786e3518c..572d3eb55eaaae 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -128,7 +128,7 @@ func.func @multiple_chained_ops(
return %7 : tensor<2xi8>
}
-// // CHECK-LABEL: func @incomplete_sharding
+// CHECK-LABEL: func @incomplete_sharding
func.func @incomplete_sharding(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
%arg0: tensor<8x16xf32>
>From 7ca5d765a0d87d913bf708585da80463c6ac9093 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 22 Feb 2024 08:26:02 -0800
Subject: [PATCH 4/4] Remove llvm:: in llvm::dyn_cast_or_null
---
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 2dace232f6aced..c4d8b0b15e462c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -628,7 +628,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
// Check if 2 shard ops are chained. If not there is no need for resharding
// as the source and target shared the same sharding.
ShardOp srcShardOp =
- llvm::dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
+ dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
if (!srcShardOp) {
targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
} else {
More information about the Mlir-commits
mailing list