[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)

Chengji Yao llvmlistbot at llvm.org
Tue May 21 19:14:12 PDT 2024


================
@@ -135,6 +135,36 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
   return %2 : tensor<2x16x32xf32>
 }
 
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+  // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+  %arg0: tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+  %arg1: tensor<3x2xf32>,
+  // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+  // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
+  // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+  // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x2xf32>
----------------
yaochengji wrote:

Yeah, this is what I mean.

https://github.com/llvm/llvm-project/pull/84514