[flang-commits] [flang] [flang] Use reduction recognition friendly pattern for hlfir.count. (PR #190856)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Tue Apr 7 14:33:27 PDT 2026
https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/190856
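This change selects between `0` and `1` based on the condition
and then adds the result to the current reduction value. The
loop-carried value is therefore updated by a single `arith.addi`,
instead of an `arith.select` between `acc + 1` and `acc`. This makes
the pattern friendlier to reduction recognition.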
From 66872d6567d65f8bc1b62cbfd100c8a42b175123 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 7 Apr 2026 14:06:29 -0700
Subject: [PATCH] [flang] Use reduction-recognition-friendly pattern for
hlfir.count.
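This change selects between `0` and `1` based on the condition
and then adds the result to the current reduction value. The
loop-carried value is therefore updated by a single `arith.addi`,
instead of an `arith.select` between `acc + 1` and `acc`. This makes
the pattern friendlier to reduction recognition.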
---
.../HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 9 +++++----
.../HLFIR/simplify-hlfir-intrinsics-count.fir | 16 ++++++++--------
2 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 7ff9dc61110d3..26c5b63cb05b6 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -1128,12 +1128,13 @@ class CountAsElementalConverter
hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
mlir::Value cond =
builder.createConvert(loc, builder.getI1Type(), elementValue);
+ mlir::Value zero =
+ builder.createIntegerConstant(loc, getResultElementType(), 0);
mlir::Value one =
builder.createIntegerConstant(loc, getResultElementType(), 1);
- mlir::Value add1 =
- mlir::arith::AddIOp::create(builder, loc, currentValue[0], one);
- return {mlir::arith::SelectOp::create(builder, loc, cond, add1,
- currentValue[0])};
+ mlir::Value addend =
+ mlir::arith::SelectOp::create(builder, loc, cond, one, zero);
+ return {mlir::arith::AddIOp::create(builder, loc, currentValue[0], addend)};
}
};
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-count.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-count.fir
index 44594c646a368..3da22f16233bc 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-count.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-count.fir
@@ -16,8 +16,8 @@ func.func @test_total_expr(%arg0: !hlfir.expr<?x?x!fir.logical<4>>) -> i32 {
// CHECK: %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (i32) {
// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_8]] : (!hlfir.expr<?x?x!fir.logical<4>>, index, index) -> !fir.logical<4>
// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (!fir.logical<4>) -> i1
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : i32
-// CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_14]], %[[VAL_15]], %[[VAL_12]] : i32
+// CHECK: %[[VAL_15:.*]] = arith.select %[[VAL_14]], %[[VAL_1]], %[[VAL_3]] : i32
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_15]] : i32
// CHECK: fir.result %[[VAL_16]] : i32
// CHECK: }
// CHECK: fir.result %[[VAL_10]] : i32
@@ -45,8 +45,8 @@ func.func @test_partial_expr(%arg0: !hlfir.expr<?x?x?x!fir.logical<1>>) -> !hlfi
// CHECK: %[[VAL_12:.*]] = fir.do_loop %[[VAL_13:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] unordered iter_args(%[[VAL_14:.*]] = %[[VAL_3]]) -> (i16) {
// CHECK: %[[VAL_15:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_13]], %[[VAL_11]] : (!hlfir.expr<?x?x?x!fir.logical<1>>, index, index, index) -> !fir.logical<1>
// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (!fir.logical<1>) -> i1
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_14]], %[[VAL_1]] : i16
-// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_16]], %[[VAL_17]], %[[VAL_14]] : i16
+// CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_16]], %[[VAL_1]], %[[VAL_3]] : i16
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_14]], %[[VAL_17]] : i16
// CHECK: fir.result %[[VAL_18]] : i16
// CHECK: }
// CHECK: hlfir.yield_element %[[VAL_12]] : i16
@@ -77,8 +77,8 @@ func.func @test_total_var(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>) -> i
// CHECK: %[[VAL_19:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_16]], %[[VAL_18]]) : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index, index) -> !fir.ref<!fir.logical<4>>
// CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_19]] : !fir.ref<!fir.logical<4>>
// CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (!fir.logical<4>) -> i1
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : i32
-// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_12]] : i32
+// CHECK: %[[VAL_22:.*]] = arith.select %[[VAL_21]], %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_12]], %[[VAL_22]] : i32
// CHECK: fir.result %[[VAL_23]] : i32
// CHECK: }
// CHECK: fir.result %[[VAL_10]] : i32
@@ -117,8 +117,8 @@ func.func @test_partial_var(%arg0: !fir.box<!fir.array<?x?x?x!fir.logical<2>>>)
// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_20]], %[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x?x?x!fir.logical<2>>>, index, index, index) -> !fir.ref<!fir.logical<2>>
// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<!fir.logical<2>>
// CHECK: %[[VAL_27:.*]] = fir.convert %[[VAL_26]] : (!fir.logical<2>) -> i1
-// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_15]], %[[VAL_1]] : i64
-// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_15]] : i64
+// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_27]], %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_15]], %[[VAL_28]] : i64
// CHECK: fir.result %[[VAL_29]] : i64
// CHECK: }
// CHECK: hlfir.yield_element %[[VAL_13]] : i64
More information about the flang-commits mailing list.