[Mlir-commits] [mlir] ea43c3f - [mlir][linalg][shard] Fix andi reduction kind in sharding partition (#192381)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 20 04:40:09 PDT 2026


Author: zackc6
Date: 2026-04-20T13:40:04+02:00
New Revision: ea43c3f7e280133997d98de3eda83fb5c5eb92ae

URL: https://github.com/llvm/llvm-project/commit/ea43c3f7e280133997d98de3eda83fb5c5eb92ae
DIFF: https://github.com/llvm/llvm-project/commit/ea43c3f7e280133997d98de3eda83fb5c5eb92ae.diff

LOG: [mlir][linalg][shard] Fix andi reduction kind in sharding partition (#192381)

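Linalg sharding now maps arith.andi combiners to the bitwise_and reduction
kind (instead of sum) when creating shard.all_reduce. Previously, partial
results of an andi reduction were combined by summation, which is incorrect
for a bitwise-AND reduction. This commit also adds a shard-partition
regression test that checks that the emitted all-reduce uses
reduction = bitwise_and for an andi reduction.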

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/shard-partition.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index d0165595f9fb6..938608afacc40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -55,7 +55,7 @@ static ReductionKind getReductionKind(Operation *op) {
       .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
       .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
       .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
-      .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
       // TODO: handle signless, signed and unsigned types properly.
       // It is assumed that the element type of the collective operands and
       // result drive the meaning of the reduction kind, whether it is signed

diff  --git a/mlir/test/Dialect/Linalg/shard-partition.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir
index aee97079fb197..b3ee6f8111f91 100644
--- a/mlir/test/Dialect/Linalg/shard-partition.mlir
+++ b/mlir/test/Dialect/Linalg/shard-partition.mlir
@@ -128,6 +128,44 @@ func.func @matmul_1d_grid_static_tensors_reduction_iterator_sharding(
 
 // -----
 
+shard.grid @grid_1d(shape = 3)
+
+// CHECK-LABEL: func @generic_1d_grid_static_tensors_andi_reduction_iterator_sharding
+func.func @generic_1d_grid_static_tensors_andi_reduction_iterator_sharding(
+  // CHECK-SAME: %[[IN:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+  %in: tensor<4x6xi8>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4xi8>
+  %dps_out: tensor<4xi8>
+// CHECK-SAME: -> tensor<4xi8> {
+) -> tensor<4xi8> {
+  %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %in_sharded1 = shard.shard %in to %sharding : tensor<4x6xi8>
+  %in_sharded2 = shard.shard %in_sharded1 to %sharding annotate_for_users : tensor<4x6xi8>
+  %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %dps_out_sharded1 = shard.shard %dps_out to %sharding2 : tensor<4xi8>
+  %dps_out_sharded2 = shard.shard %dps_out_sharded1 to %sharding2 annotate_for_users : tensor<4xi8>
+  // CHECK: %[[SHARDED_GENERIC:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[IN]] : tensor<4x2xi8>)
+  // CHECK: } -> tensor<4xi8>
+  // CHECK: %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_GENERIC]] on @grid_1d grid_axes = [0] reduction = bitwise_and : tensor<4xi8> -> tensor<4xi8>
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]
+    } ins(%in_sharded2 : tensor<4x6xi8>)
+      outs(%dps_out_sharded2 : tensor<4xi8>) {
+    ^bb0(%in_scalar: i8, %out_scalar: i8):
+      %res_scalar = arith.andi %in_scalar, %out_scalar : i8
+      linalg.yield %res_scalar : i8
+    } -> tensor<4xi8>
+  %res_sharded1 = shard.shard %res to %sharding2 : tensor<4xi8>
+  %res_sharded2 = shard.shard %res_sharded1 to %sharding2 annotate_for_users : tensor<4xi8>
+  // CHECK: return %[[ALL_REDUCED]] : tensor<4xi8>
+  return %res_sharded2 : tensor<4xi8>
+}
+
+// -----
+
 shard.grid @grid_1d(shape = 4)
 
 // CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis


        


More information about the Mlir-commits mailing list