[Mlir-commits] [mlir] [mlir][mesh] Fix empty `split_axes` sharding annotation. (PR #108236)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 11 07:45:48 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Matteo Franciolini (mfrancio)
<details>
<summary>Changes</summary>
The `split_axes` attribute is defined as "array attribute of array
attributes". Following the definition, empty `split_axes` values should
not be allowed, since that would break the definition and would lead to
invalid IR. In such scenario, passes leveraging the mesh dialect can
observe:
* crashes in sharding-propagation;
* creation of null MeshShardingAttrs in spmdization;
* non roundtrippable IR.
The patch prevents `split_axes` to become empty by modifying the
`removeTrailingEmptySubArray` such that a minimum size of one is
guaranteed when constructing the attribute, and adds a test that would
crash without the change.
---
Full diff: https://github.com/llvm/llvm-project/pull/108236.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+2-1)
- (modified) mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir (+1-1)
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+20-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 683975bbf215ed..db7b64fda57d7b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -98,9 +98,10 @@ inline bool isReductionLoop(utils::IteratorType iType) {
return iType == utils::IteratorType::reduction;
}
+// Remove empty subarrays of `array` until a minimum lengh of one is reached.
template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
- while (!array.empty() && array.back().empty())
+ while (array.size() > 1 && array.back().empty())
array.pop_back();
}
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
index f8521165e3244e..5297eeb666c1e1 100644
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -18,7 +18,7 @@ func.func @matmul_shard_prallel_axis(
// CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
// CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
// CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
- // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+ // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
// CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
// CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
// CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 5b00b45653dbb6..83136f613b020a 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -179,7 +179,7 @@ func.func @resolve_conflicting_annotations(
) -> tensor<2x2xf32> {
// CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
- // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+ // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
// CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
// CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
// CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
@@ -266,3 +266,22 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
// CHECK-DAG: return %[[V12]]
return %6 : tensor<2x4x8xf32>
}
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] : tensor<8x16xf32>
+ %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
+ %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V5]]
+ return %2 : tensor<8x16xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/108236
More information about the Mlir-commits
mailing list