[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)
Frank Schlimbach
llvmlistbot at llvm.org
Mon May 4 06:37:44 PDT 2026
================
@@ -1,3 +1,105 @@
+// RUN: mlir-opt -shard-simplify %s | FileCheck %s
+
+// @grid_ag_alt has the same 2x2 shape as @grid_ag; the *_different_grid tests
+// use it so that the grid symbol is the only thing differing between the ops.
+shard.grid @grid_ag(shape = 2x2)
+shard.grid @grid_ag_alt(shape = 2x2)
+
+// Gather along tensor axis 1 over grid axis 1, then slice the same tensor axis
+// over the same grid axis of the same grid: the pair is expected to fold away,
+// leaving %arg0 as the return value.
+// CHECK-LABEL: func.func @all_gather_all_slice_identity
+func.func @all_gather_all_slice_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %gathered = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x4xf32> -> tensor<4x8xf32>
+ %sliced = shard.all_slice %gathered on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK-NOT: shard.all_slice
+ // CHECK-NOT: shard.all_gather
+ // CHECK: return %arg0 : tensor<4x4xf32>
+ return %sliced : tensor<4x4xf32>
+}
+
+// Gather on tensor axis 1 but slice on tensor axis 0: not an inverse pair, so
+// both collectives must remain, chained, and still feed the return.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_axis(
+func.func @all_gather_all_slice_different_axis(
+ %arg0: tensor<4x4xf32>) -> tensor<2x8xf32> {
+ // CHECK: %[[GATHER:.*]] = shard.all_gather %arg0
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x4xf32> -> tensor<4x8xf32>
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %[[GATHER]]
+ %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+ : tensor<4x8xf32> -> tensor<2x8xf32>
+ // CHECK: return %[[SLICE]] : tensor<2x8xf32>
+ return %1 : tensor<2x8xf32>
+}
+
+// Gather over grid axis 0 but slice over grid axis 1: not an inverse pair, so
+// both collectives must remain, chained, and still feed the return.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes(
+func.func @all_gather_all_slice_different_grid_axes(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK: %[[GATHER:.*]] = shard.all_gather %arg0
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [0] gather_axis = 0
+ : tensor<4x4xf32> -> tensor<8x4xf32>
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %[[GATHER]]
+ %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+ : tensor<8x4xf32> -> tensor<4x4xf32>
+ // CHECK: return %[[SLICE]] : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// Same axes and shapes, but the slice runs on @grid_ag_alt instead of @grid_ag:
+// a different grid symbol must block the fold.
+// The trailing "(" keeps this label from also matching
+// @all_gather_all_slice_different_grid_axes.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid(
+func.func @all_gather_all_slice_different_grid(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK: %[[GATHER:.*]] = shard.all_gather %arg0 on @grid_ag
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x4xf32> -> tensor<4x8xf32>
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %[[GATHER]] on @grid_ag_alt
+ %1 = shard.all_slice %0 on @grid_ag_alt grid_axes = [1] slice_axis = 1
+ : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK: return %[[SLICE]] : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// Slice along tensor axis 1 over grid axis 1, then gather the same tensor axis
+// over the same grid axis: the test expects the pair to fold to %arg0.
+// NOTE(review): as far as I can tell from the all_slice op description, its
+// input need not be identical on every process; each process slices its own
+// local tensor. all_slice followed by all_gather then reproduces %arg0 only
+// when %arg0 is replicated along grid axis 1 of @grid_ag. Please confirm that
+// the -shard-simplify fold is meant to assume replication, and document it.
+// CHECK-LABEL: func.func @all_slice_all_gather_identity(
+func.func @all_slice_all_gather_identity(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x4xf32> -> tensor<4x2xf32>
+ %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x2xf32> -> tensor<4x4xf32>
+ // CHECK-NOT: shard.all_slice
+ // CHECK-NOT: shard.all_gather
+ // CHECK: return %arg0 : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// Slice on tensor axis 1 but gather on tensor axis 0: not an inverse pair, so
+// both collectives must remain, chained, and still feed the return.
+// CHECK-LABEL: func.func @all_slice_all_gather_different_axis(
+func.func @all_slice_all_gather_different_axis(
+ %arg0: tensor<4x4xf32>) -> tensor<8x2xf32> {
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %arg0
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x4xf32> -> tensor<4x2xf32>
+ // CHECK: %[[GATHER:.*]] = shard.all_gather %[[SLICE]]
+ %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+ : tensor<4x2xf32> -> tensor<8x2xf32>
+ // CHECK: return %[[GATHER]] : tensor<8x2xf32>
+ return %1 : tensor<8x2xf32>
+}
+
+// Same axes and shapes, but the gather runs on @grid_ag_alt instead of @grid_ag:
+// a different grid symbol must block the fold.
+// The trailing "(" keeps this label from also matching
+// @all_slice_all_gather_different_grid_axes.
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid(
+func.func @all_slice_all_gather_different_grid(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %arg0 on @grid_ag
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x4xf32> -> tensor<4x2xf32>
+ // CHECK: %[[GATHER:.*]] = shard.all_gather %[[SLICE]] on @grid_ag_alt
+ %1 = shard.all_gather %0 on @grid_ag_alt grid_axes = [1] gather_axis = 1
+ : tensor<4x2xf32> -> tensor<4x4xf32>
+ // CHECK: return %[[GATHER]] : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// Slice over grid axis 0 but gather over grid axis 1: not an inverse pair, so
+// both collectives must remain, chained, and still feed the return.
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid_axes(
+func.func @all_slice_all_gather_different_grid_axes(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %arg0
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [0] slice_axis = 0
+ : tensor<4x4xf32> -> tensor<2x4xf32>
+ // CHECK: %[[GATHER:.*]] = shard.all_gather %[[SLICE]]
+ %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+ : tensor<2x4xf32> -> tensor<4x4xf32>
+ // CHECK: return %[[GATHER]] : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
// The RUN line at the top of the file already runs FileCheck over every test in this file.
----------------
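fschlimb wrote:
```suggestion
```
(The empty suggestion proposes deleting the pre-existing `mlir-opt -shard-simplify %s | FileCheck %s` RUN line, which duplicates the RUN line added at the top of the test file and would make lit run the whole file twice.)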
https://github.com/llvm/llvm-project/pull/193906
More information about the Mlir-commits
mailing list