[Mlir-commits] [mlir] [mlir][mesh] Fix empty `split_axes` sharding annotation. (PR #108236)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 11 07:45:48 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Matteo Franciolini (mfrancio)

<details>
<summary>Changes</summary>

The `split_axes` attribute is defined as "array attribute of array
attributes". Following the definition, empty `split_axes` values should
not be allowed, since that would break the definition and would lead to
invalid IR. In such a scenario, passes leveraging the mesh dialect can
observe:
* crashes in sharding-propagation;
* creation of null MeshShardingAttrs in spmdization;
* non-round-trippable IR.

The patch prevents `split_axes` from becoming empty by modifying
`removeTrailingEmptySubArray` so that a minimum size of one is
guaranteed when constructing the attribute, and adds a test that would
crash without the change.



---
Full diff: https://github.com/llvm/llvm-project/pull/108236.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+2-1) 
- (modified) mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir (+1-1) 
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+20-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 683975bbf215ed..db7b64fda57d7b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -98,9 +98,10 @@ inline bool isReductionLoop(utils::IteratorType iType) {
   return iType == utils::IteratorType::reduction;
 }
 
+// Pop trailing empty subarrays of `array`, never shrinking it below one element.
 template <typename T>
 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
-  while (!array.empty() && array.back().empty())
+  while (array.size() > 1 && array.back().empty())
     array.pop_back();
 }
 
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
index f8521165e3244e..5297eeb666c1e1 100644
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -18,7 +18,7 @@ func.func @matmul_shard_prallel_axis(
   // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
   // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
   // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
-  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
   // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
   // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
   // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 5b00b45653dbb6..83136f613b020a 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -179,7 +179,7 @@ func.func @resolve_conflicting_annotations(
 ) -> tensor<2x2xf32> {
   // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
   // CHECK-NEXT:  %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]]  : tensor<2x3xf32>
-  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
   // CHECK-NEXT:  %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x3xf32>
   // CHECK-NEXT:  %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<3x2xf32>
   // CHECK-NEXT:  %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x2xf32>
@@ -266,3 +266,22 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-DAG: return %[[V12]]
   return %6 : tensor<2x4x8xf32>
 }
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = tosa.sigmoid %[[V3]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]]  : tensor<8x16xf32>
+  %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
+  %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V5]]
+  return %2 : tensor<8x16xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/108236


More information about the Mlir-commits mailing list