[Mlir-commits] [mlir] [mlir][mesh] Add endomorphism simplification for all-reduce (PR #73150)

Chengji Yao llvmlistbot at llvm.org
Mon Nov 27 20:36:25 PST 2023


================
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 2])
+mesh.cluster @mesh1(rank = 1, dim_sizes = [4])
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
----------------
yaochengji wrote:

Could you add a new `// CHECK-NOT: mesh.all_reduce` line here?

https://github.com/llvm/llvm-project/pull/73150


More information about the Mlir-commits mailing list