[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)
Chengji Yao
llvmlistbot at llvm.org
Sun Mar 10 20:40:35 PDT 2024
================
@@ -135,6 +135,36 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
return %2 : tensor<2x16x32xf32>
}
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+ // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+ %arg0: tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+ %arg1: tensor<3x2xf32>,
+ // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+ // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
+ // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+ %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+ // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+ // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x2xf32>
----------------
yaochengji wrote:
Should we respect the `mesh.shard` annotation of `%res`? We explicitly marked it as `<@mesh_2, [[]]>`, but after sharding propagation it became `<@mesh_2, [[0]]>`.
In my view, we should instead insert a resharding for `%arg0`: add an additional `mesh.shard` whose annotation matches the one on `%res`.
https://github.com/llvm/llvm-project/pull/84514