[Mlir-commits] [mlir] [mlir][shard, mpi] Allow more than one last axis to be "unsplit" (PR #180754)

Tuomas Kärnä llvmlistbot at llvm.org
Wed Feb 11 01:26:52 PST 2026


================
@@ -52,6 +53,29 @@ func.func @sharding_triplet(
   return %sharded_1 : tensor<2xf32>
 }
 
+// CHECK-LABEL: func.func @unsplit_last_axes_some(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> {
+func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> {
+  %sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding
+  %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x48xi8>
+  %sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding
+  %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x48xi8>
----------------
tkarna wrote:

nit: variable naming for clarity. AFAIK there's no replication here, just two different shardings.
%in2_replicated -> %in2_sharded1
%in2_sharded -> %in2_sharded2
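
For reference, a sketch of the quoted lines with these renames applied. It covers only the four lines visible in the hunk above; the rest of the function is not shown there, so this is a fragment rather than a complete test:

```mlir
// Fragment of @unsplit_last_axes_some with the suggested names applied.
// Only the SSA value names change; ops, attributes, and types are as quoted.
%sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding
%in2_sharded1 = shard.shard %in2 to %sharding1 : tensor<6x48xi8>
%sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding
%in2_sharded2 = shard.shard %in2_sharded1 to %sharding2 annotate_for_users : tensor<6x48xi8>
```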

Also applies to the next test.

https://github.com/llvm/llvm-project/pull/180754

