[Mlir-commits] [mlir] ea43c3f - [mlir][linalg][shard] Fix andi reduction kind in sharding partition (#192381)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 20 04:40:09 PDT 2026
Author: zackc6
Date: 2026-04-20T13:40:04+02:00
New Revision: ea43c3f7e280133997d98de3eda83fb5c5eb92ae
URL: https://github.com/llvm/llvm-project/commit/ea43c3f7e280133997d98de3eda83fb5c5eb92ae
DIFF: https://github.com/llvm/llvm-project/commit/ea43c3f7e280133997d98de3eda83fb5c5eb92ae.diff
LOG: [mlir][linalg][shard] Fix andi reduction kind in sharding partition (#192381)
Linalg sharding now maps arith.andi combiners to bitwise_and (instead of
sum) when creating shard.all_reduce. This commit also adds a shard-partition
regression test that checks that the emitted all-reduce uses
reduction = bitwise_and for an andi reduction.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
mlir/test/Dialect/Linalg/shard-partition.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index d0165595f9fb6..938608afacc40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -55,7 +55,7 @@ static ReductionKind getReductionKind(Operation *op) {
.Case([](arith::AddIOp op) { return ReductionKind::Sum; })
.Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
.Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
- .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
+ .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
// TODO: handle signless, signed and unsigned types properly.
// It is assumed that the element type of the collective operands and
// result drive the meaning of the reduction kind, whether it is signed
diff --git a/mlir/test/Dialect/Linalg/shard-partition.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir
index aee97079fb197..b3ee6f8111f91 100644
--- a/mlir/test/Dialect/Linalg/shard-partition.mlir
+++ b/mlir/test/Dialect/Linalg/shard-partition.mlir
@@ -128,6 +128,44 @@ func.func @matmul_1d_grid_static_tensors_reduction_iterator_sharding(
// -----
+shard.grid @grid_1d(shape = 3)
+
+// CHECK-LABEL: func @generic_1d_grid_static_tensors_andi_reduction_iterator_sharding
+func.func @generic_1d_grid_static_tensors_andi_reduction_iterator_sharding(
+ // CHECK-SAME: %[[IN:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+ %in: tensor<4x6xi8>,
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4xi8>
+ %dps_out: tensor<4xi8>
+// CHECK-SAME: -> tensor<4xi8> {
+) -> tensor<4xi8> {
+ %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %in_sharded1 = shard.shard %in to %sharding : tensor<4x6xi8>
+ %in_sharded2 = shard.shard %in_sharded1 to %sharding annotate_for_users : tensor<4x6xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %dps_out_sharded1 = shard.shard %dps_out to %sharding2 : tensor<4xi8>
+ %dps_out_sharded2 = shard.shard %dps_out_sharded1 to %sharding2 annotate_for_users : tensor<4xi8>
+ // CHECK: %[[SHARDED_GENERIC:.*]] = linalg.generic
+ // CHECK-SAME: ins(%[[IN]] : tensor<4x2xi8>)
+ // CHECK: } -> tensor<4xi8>
+ // CHECK: %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_GENERIC]] on @grid_1d grid_axes = [0] reduction = bitwise_and : tensor<4xi8> -> tensor<4xi8>
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]
+ } ins(%in_sharded2 : tensor<4x6xi8>)
+ outs(%dps_out_sharded2 : tensor<4xi8>) {
+ ^bb0(%in_scalar: i8, %out_scalar: i8):
+ %res_scalar = arith.andi %in_scalar, %out_scalar : i8
+ linalg.yield %res_scalar : i8
+ } -> tensor<4xi8>
+ %res_sharded1 = shard.shard %res to %sharding2 : tensor<4xi8>
+ %res_sharded2 = shard.shard %res_sharded1 to %sharding2 annotate_for_users : tensor<4xi8>
+ // CHECK: return %[[ALL_REDUCED]] : tensor<4xi8>
+ return %res_sharded2 : tensor<4xi8>
+}
+
+// -----
+
shard.grid @grid_1d(shape = 4)
// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis
More information about the Mlir-commits
mailing list