[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)

Frank Schlimbach llvmlistbot at llvm.org
Mon May 4 06:37:44 PDT 2026


================
@@ -1,3 +1,105 @@
+// RUN: mlir-opt %s -shard-simplify | FileCheck %s
+
+// Two grids with the same 2x2 shape. The *_different_grid tests below use
+// @grid_ag_alt so that the grid symbol is the only thing that differs between
+// the paired collectives (grid_axes, tensor axis and types all match).
+shard.grid @grid_ag(shape = 2x2)
+shard.grid @grid_ag_alt(shape = 2x2)
+
+// Positive case: all_gather followed by all_slice on the same grid, the same
+// grid_axes ([1]) and the same tensor axis (1). The pair is expected to fold
+// away entirely, leaving the function to return its argument directly.
+// CHECK-LABEL: func.func @all_gather_all_slice_identity
+func.func @all_gather_all_slice_identity(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK-NOT: shard.all_gather
+  // CHECK-NOT: shard.all_slice
+  // CHECK: return %arg0 : tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// Negative case: the gather axis (1) and the slice axis (0) differ, so the
+// pair is not an inverse; both ops are expected to remain.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_axis
+func.func @all_gather_all_slice_different_axis(
+    %arg0: tensor<4x4xf32>) -> tensor<2x8xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<2x8xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<2x8xf32>
+}
+
+// Negative case: same tensor axis (0) but different grid axes ([0] vs [1]);
+// both ops are expected to remain. Both grid axes have size 2, so the
+// round-trip types (4x4 -> 8x4 -> 4x4) match and cannot by themselves rule
+// out a fold -- the grid_axes comparison must.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes
+func.func @all_gather_all_slice_different_grid_axes(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [0] gather_axis = 0
+    : tensor<4x4xf32> -> tensor<8x4xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+    : tensor<8x4xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<4x4xf32>
+}
+
+// Negative case: grid_axes, tensor axis and types all match, but the ops refer
+// to different grid symbols (@grid_ag vs @grid_ag_alt); both ops are expected
+// to remain.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid
+func.func @all_gather_all_slice_different_grid(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag_alt grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<4x4xf32>
+}
+
+// Positive case for the reverse order: all_slice followed by all_gather on the
+// same grid, grid_axes ([1]) and tensor axis (1) is expected to fold, leaving
+// the function to return its argument directly.
+// NOTE(review): unlike gather-then-slice, this round trip only reproduces the
+// input if %arg0 holds the same value on every device along grid axis 1 --
+// as I read the Shard dialect docs, all_slice slices each device's *local*
+// tensor without communicating, so non-replicated inputs would be mixed by
+// the subsequent all_gather. Confirm the fold is meant to assume replicated
+// input (and document that), or restrict it.
+// CHECK-LABEL: func.func @all_slice_all_gather_identity
+func.func @all_slice_all_gather_identity(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x4xf32> -> tensor<4x2xf32>
+  %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x2xf32> -> tensor<4x4xf32>
+  // CHECK-NOT: shard.all_slice
+  // CHECK-NOT: shard.all_gather
+  // CHECK: return %arg0 : tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// Negative case: the slice axis (1) and the gather axis (0) differ, so the
+// pair is not an inverse; both ops are expected to remain.
+// CHECK-LABEL: func.func @all_slice_all_gather_different_axis
+func.func @all_slice_all_gather_different_axis(
+    %arg0: tensor<4x4xf32>) -> tensor<8x2xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x4xf32> -> tensor<4x2xf32>
+  %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+    : tensor<4x2xf32> -> tensor<8x2xf32>
+  // CHECK: shard.all_slice
+  // CHECK: shard.all_gather
+  return %1 : tensor<8x2xf32>
+}
+
+// Negative case: grid_axes, tensor axis and types all match, but the ops refer
+// to different grid symbols (@grid_ag vs @grid_ag_alt); both ops are expected
+// to remain.
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid
+func.func @all_slice_all_gather_different_grid(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x4xf32> -> tensor<4x2xf32>
+  %1 = shard.all_gather %0 on @grid_ag_alt grid_axes = [1] gather_axis = 1
+    : tensor<4x2xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_slice
+  // CHECK: shard.all_gather
+  return %1 : tensor<4x4xf32>
+}
+
+// Negative case: same tensor axis (0) but different grid axes ([0] vs [1]);
+// both ops are expected to remain. Both grid axes have size 2, so the
+// round-trip types (4x4 -> 2x4 -> 4x4) match and cannot by themselves rule
+// out a fold -- the grid_axes comparison must.
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid_axes
+func.func @all_slice_all_gather_different_grid_axes(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [0] slice_axis = 0
+    : tensor<4x4xf32> -> tensor<2x4xf32>
+  %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+    : tensor<2x4xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_slice
+  // CHECK: shard.all_gather
+  return %1 : tensor<4x4xf32>
+}
 // RUN: mlir-opt -shard-simplify %s | FileCheck %s
----------------
fschlimb wrote:

```suggestion
```

https://github.com/llvm/llvm-project/pull/193906


More information about the Mlir-commits mailing list