[flang-commits] [flang] 79dccde - [flang] Change COUNT intrinsic to support different kind logicals
Sacha Ballantyne via flang-commits
flang-commits at lists.llvm.org
Tue Feb 28 04:26:38 PST 2023
Author: Sacha Ballantyne
Date: 2023-02-28T12:26:33Z
New Revision: 79dccded69000d431a3c37b911cfc05a67b14967
URL: https://github.com/llvm/llvm-project/commit/79dccded69000d431a3c37b911cfc05a67b14967
DIFF: https://github.com/llvm/llvm-project/commit/79dccded69000d431a3c37b911cfc05a67b14967.diff
LOG: [flang] Change COUNT intrinsic to support different kind logicals
Previously COUNT would cast the mask input to logical<4> before passing it
to the runtime function; this has been changed to allow logicals of any kind.
Reviewed By: tblah
Differential Revision: https://reviews.llvm.org/D144867
Added:
Modified:
flang/lib/Evaluate/fold-integer.cpp
flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
flang/test/Transforms/simplifyintrinsics.fir
Removed:
################################################################################
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index cd8e8f5839a8b..4a66d437f06c1 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -237,8 +237,9 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
}
// COUNT()
-template <typename T>
+template <typename T, int maskKind>
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
+ using LogicalResult = Type<TypeCategory::Logical, maskKind>;
static_assert(T::category == TypeCategory::Integer);
ActualArguments &arg{ref.arguments()};
if (const Constant<LogicalResult> *mask{arg.empty()
@@ -546,7 +547,18 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
cx->u);
}
} else if (name == "count") {
- return FoldCount<T>(context, std::move(funcRef));
+ int maskKind = args[0]->GetType()->kind();
+ switch (maskKind) {
+ SWITCH_COVERS_ALL_CASES
+ case 1:
+ return FoldCount<T, 1>(context, std::move(funcRef));
+ case 2:
+ return FoldCount<T, 2>(context, std::move(funcRef));
+ case 4:
+ return FoldCount<T, 4>(context, std::move(funcRef));
+ case 8:
+ return FoldCount<T, 8>(context, std::move(funcRef));
+ }
} else if (name == "digits") {
if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
return Expr<T>{common::visit(
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 53ab094ca02fb..9de7ae16d9e4f 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -662,7 +662,7 @@ static void genRuntimeCountBody(fir::FirOpBuilder &builder,
auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType, mlir::Value elem1,
mlir::Value elem2) -> mlir::Value {
- auto zero32 = builder.createIntegerConstant(loc, builder.getI32Type(), 0);
+ auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index 806eeb2bd06ae..8b3086bde9b15 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -1161,6 +1161,54 @@ fir.global linkonce @_QQcl.2E2F746573746661696C2E66393000 constant : !fir.char<1
// CHECK: return %[[RES:.*]] : i64
// CHECK: }
+// -----
+// Ensure count is properly simplified for different mask kind
+
+func.func @_QPdiffkind(%arg0: !fir.ref<!fir.array<10x!fir.logical<2>>> {fir.bindc_name = "mask"}) -> i32 {
+ %0 = fir.alloca i32 {bindc_name = "diffkind", uniq_name = "_QFdiffkindEdiffkind"}
+ %c10 = arith.constant 10 : index
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10x!fir.logical<2>>>, !fir.shape<1>) -> !fir.box<!fir.array<10x!fir.logical<2>>>
+ %c0 = arith.constant 0 : index
+ %3 = fir.address_of(@_QQcl.916d74b25894ddf7881ff7f913a677f5) : !fir.ref<!fir.char<1,52>>
+ %c5_i32 = arith.constant 5 : i32
+ %4 = fir.convert %2 : (!fir.box<!fir.array<10x!fir.logical<2>>>) -> !fir.box<none>
+ %5 = fir.convert %3 : (!fir.ref<!fir.char<1,52>>) -> !fir.ref<i8>
+ %6 = fir.convert %c0 : (index) -> i32
+ %7 = fir.call @_FortranACount(%4, %5, %c5_i32, %6) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
+ %8 = fir.convert %7 : (i64) -> i32
+ fir.store %8 to %0 : !fir.ref<i32>
+ %9 = fir.load %0 : !fir.ref<i32>
+ return %9 : i32
+}
+
+// CHECK-LABEL: func.func @_QPdiffkind(
+// CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<10x!fir.logical<2>>> {fir.bindc_name = "mask"}) -> i32 {
+// CHECK: %[[res:.*]] = fir.call @_FortranACountLogical2x1_simplified({{.*}}) fastmath<contract> : (!fir.box<none>) -> i64
+
+// CHECK-LABEL: func.func private @_FortranACountLogical2x1_simplified(
+// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[C_INDEX0:.*]] = arith.constant 0 : index
+// CHECK: %[[ARR_BOX_I16:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi16>>
+// CHECK: %[[IZERO:.*]] = arith.constant 0 : i64
+// CHECK: %[[C_INDEX1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I16]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi16>>, index) -> (index, index, index)
+// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[C_INDEX1]] : index
+// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[C_INDEX0]] to %[[EXTENT]] step %[[C_INDEX1]] iter_args(%[[COUNT:.*]] = %[[IZERO]]) -> (i64) {
+// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I16]], %[[ITER]] : (!fir.box<!fir.array<?xi16>>, index) -> !fir.ref<i16>
+// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i16>
+// CHECK: %[[I16_0:.*]] = arith.constant 0 : i16
+// CHECK: %[[I64_0:.*]] = arith.constant 0 : i64
+// CHECK: %[[I64_1:.*]] = arith.constant 1 : i64
+// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[ITEM_VAL]], %[[I16_0]] : i16
+// CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %[[I64_0]], %[[I64_1]] : i64
+// CHECK: %[[NEW_COUNT:.*]] = arith.addi %[[SELECT]], %[[COUNT]] : i64
+// CHECK: fir.result %[[NEW_COUNT]] : i64
+// CHECK: }
+// CHECK: return %[[RES:.*]] : i64
+// CHECK: }
+
// -----
// Ensure count isn't simplified when given dim argument
More information about the flang-commits mailing list