[Mlir-commits] [mlir] [mlir][mesh] Fix wrong argument passed to targetShardingInUnsplitLast… (PR #95059)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 10 16:37:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Arda Unal (ardaunal)

<details>
<summary>Changes</summary>

…Axis

In unsplitLastAxisInResharding, the wrong argument was passed when calling targetShardingInUnsplitLastAxis (splitMeshAxis instead of splitTensorAxis). There weren't any tests to uncover this. I added one in mesh-spmdization.mlir for the Linalg dialect and one in resharding-spmdization.mlir for the Mesh dialect.

---
Full diff: https://github.com/llvm/llvm-project/pull/95059.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+1-1) 
- (modified) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (+35) 
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+13) 


``````````diff
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index f3e4b15aec118..4ecc897103af7 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -266,7 +266,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
   builder.setInsertionPointAfterValue(sourceShard);
 
   MeshShardingAttr targetSharding =
-      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
+      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
       sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
   Value allGatherResult = builder.create<AllGatherOp>(
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index bd56c801283b1..52f352cfedd8e 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -162,3 +162,38 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
   // CHECK:      return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
   return %res_shared2 : tensor<4x8xi8>
 }
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 4)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
+func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
+  %in1: tensor<4x6xi8>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
+  %in2: tensor<6x8xi8>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+  %dps_out: tensor<4x8xi8>
+  // CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+  %in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8>
+  %in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8>
+  // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
+  %in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8>
+  %in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8>
+  // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
+  %dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8>
+  %dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8>
+  // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
+  // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
+  // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
+  // CHECK-SAME: -> tensor<4x2xi8>
+  %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
+      outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
+  %res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8>
+  %res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
+  // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
+  return %res_replicated : tensor<4x8xi8>
+}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index ba05306598bcc..b3e305135ad8b 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -96,6 +96,19 @@ func.func @unshard_static_axis(
   return %1 : tensor<10x14xf32>
 }
 
+// CHECK-LABEL: func @unshard_static_last_axis
+func.func @unshard_static_last_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
 // CHECK-LABEL: func @unshard_dynamic_axis
 func.func @unshard_dynamic_axis(
   // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/95059


More information about the Mlir-commits mailing list