[flang-commits] [flang] 79dccde - [flang] Change COUNT intrinsic to support different kind logicals

Sacha Ballantyne via flang-commits flang-commits at lists.llvm.org
Tue Feb 28 04:26:38 PST 2023


Author: Sacha Ballantyne
Date: 2023-02-28T12:26:33Z
New Revision: 79dccded69000d431a3c37b911cfc05a67b14967

URL: https://github.com/llvm/llvm-project/commit/79dccded69000d431a3c37b911cfc05a67b14967
DIFF: https://github.com/llvm/llvm-project/commit/79dccded69000d431a3c37b911cfc05a67b14967.diff

LOG: [flang] Change COUNT intrinsic to support different kind logicals

Previously, COUNT cast the mask input to logical<4> before passing it
to the runtime function; this has been changed to allow logical kinds other than 4.

Reviewed By: tblah

Differential Revision: https://reviews.llvm.org/D144867

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-integer.cpp
    flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
    flang/test/Transforms/simplifyintrinsics.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index cd8e8f5839a8b..4a66d437f06c1 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -237,8 +237,9 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
 }
 
 // COUNT()
-template <typename T>
+template <typename T, int maskKind>
 static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
+  using LogicalResult = Type<TypeCategory::Logical, maskKind>;
   static_assert(T::category == TypeCategory::Integer);
   ActualArguments &arg{ref.arguments()};
   if (const Constant<LogicalResult> *mask{arg.empty()
@@ -546,7 +547,18 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
           cx->u);
     }
   } else if (name == "count") {
-    return FoldCount<T>(context, std::move(funcRef));
+    int maskKind = args[0]->GetType()->kind();
+    switch (maskKind) {
+      SWITCH_COVERS_ALL_CASES
+    case 1:
+      return FoldCount<T, 1>(context, std::move(funcRef));
+    case 2:
+      return FoldCount<T, 2>(context, std::move(funcRef));
+    case 4:
+      return FoldCount<T, 4>(context, std::move(funcRef));
+    case 8:
+      return FoldCount<T, 8>(context, std::move(funcRef));
+    }
   } else if (name == "digits") {
     if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
       return Expr<T>{common::visit(

diff  --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 53ab094ca02fb..9de7ae16d9e4f 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -662,7 +662,7 @@ static void genRuntimeCountBody(fir::FirOpBuilder &builder,
   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
                       mlir::Type elementType, mlir::Value elem1,
                       mlir::Value elem2) -> mlir::Value {
-    auto zero32 = builder.createIntegerConstant(loc, builder.getI32Type(), 0);
+    auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
     auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
     auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
 

diff  --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index 806eeb2bd06ae..8b3086bde9b15 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -1161,6 +1161,54 @@ fir.global linkonce @_QQcl.2E2F746573746661696C2E66393000 constant : !fir.char<1
 // CHECK:           return %[[RES:.*]] : i64
 // CHECK:         }
 
+// -----
+// Ensure count is properly simplified for different mask kind
+
+func.func @_QPdiffkind(%arg0: !fir.ref<!fir.array<10x!fir.logical<2>>> {fir.bindc_name = "mask"}) -> i32 {
+  %0 = fir.alloca i32 {bindc_name = "diffkind", uniq_name = "_QFdiffkindEdiffkind"}
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10x!fir.logical<2>>>, !fir.shape<1>) -> !fir.box<!fir.array<10x!fir.logical<2>>>
+  %c0 = arith.constant 0 : index
+  %3 = fir.address_of(@_QQcl.916d74b25894ddf7881ff7f913a677f5) : !fir.ref<!fir.char<1,52>>
+  %c5_i32 = arith.constant 5 : i32
+  %4 = fir.convert %2 : (!fir.box<!fir.array<10x!fir.logical<2>>>) -> !fir.box<none>
+  %5 = fir.convert %3 : (!fir.ref<!fir.char<1,52>>) -> !fir.ref<i8>
+  %6 = fir.convert %c0 : (index) -> i32
+  %7 = fir.call @_FortranACount(%4, %5, %c5_i32, %6) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
+  %8 = fir.convert %7 : (i64) -> i32
+  fir.store %8 to %0 : !fir.ref<i32>
+  %9 = fir.load %0 : !fir.ref<i32>
+  return %9 : i32
+}
+
+// CHECK-LABEL:   func.func @_QPdiffkind(
+// CHECK-SAME:                           %[[A:.*]]: !fir.ref<!fir.array<10x!fir.logical<2>>> {fir.bindc_name = "mask"}) -> i32 {
+// CHECK:           %[[res:.*]] = fir.call @_FortranACountLogical2x1_simplified({{.*}}) fastmath<contract> : (!fir.box<none>) -> i64
+
+// CHECK-LABEL:   func.func private @_FortranACountLogical2x1_simplified(
+// CHECK-SAME:                                                           %[[ARR:.*]]: !fir.box<none>) -> i64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[C_INDEX0:.*]] = arith.constant 0 : index
+// CHECK:           %[[ARR_BOX_I16:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi16>>
+// CHECK:           %[[IZERO:.*]] = arith.constant 0 : i64
+// CHECK:           %[[C_INDEX1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I16]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi16>>, index) -> (index, index, index)
+// CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[C_INDEX1]] : index
+// CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[C_INDEX0]] to %[[EXTENT]] step %[[C_INDEX1]] iter_args(%[[COUNT:.*]] = %[[IZERO]]) -> (i64) {
+// CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I16]], %[[ITER]] : (!fir.box<!fir.array<?xi16>>, index) -> !fir.ref<i16>
+// CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i16>
+// CHECK:             %[[I16_0:.*]] = arith.constant 0 : i16
+// CHECK:             %[[I64_0:.*]] = arith.constant 0 : i64
+// CHECK:             %[[I64_1:.*]] = arith.constant 1 : i64
+// CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[ITEM_VAL]], %[[I16_0]] : i16
+// CHECK:             %[[SELECT:.*]] = arith.select %[[CMP]], %[[I64_0]], %[[I64_1]] : i64
+// CHECK:             %[[NEW_COUNT:.*]] = arith.addi %[[SELECT]], %[[COUNT]] : i64
+// CHECK:             fir.result %[[NEW_COUNT]] : i64
+// CHECK:           }
+// CHECK:           return %[[RES:.*]] : i64
+// CHECK:         }
+
 // -----
 // Ensure count isn't simplified when given dim argument
 


        


More information about the flang-commits mailing list