[flang-commits] [flang] 7d2e198 - [flang] Add Count to simplified intrinsics

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Fri Jan 27 08:30:51 PST 2023


Author: Sacha Ballantyne
Date: 2023-01-27T16:30:11Z
New Revision: 7d2e198729df14a7e025d44ae8aa21ce14be9baa

URL: https://github.com/llvm/llvm-project/commit/7d2e198729df14a7e025d44ae8aa21ce14be9baa
DIFF: https://github.com/llvm/llvm-project/commit/7d2e198729df14a7e025d44ae8aa21ce14be9baa.diff

LOG: [flang] Add Count to simplified intrinsics

This patch adds a simplified version of COUNT to the simplify intrinsics pass, allowing the function to be inlined.

This was done specifically to help improve performance for exchange2, and provides a ~12% performance increase.

Reviewed By: vzakhari, Leporacanthicus

Differential Revision: https://reviews.llvm.org/D142209

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
    flang/test/Transforms/simplifyintrinsics.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 25f161ccdf7ed..c1f7f39ad356a 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -73,14 +73,22 @@ class SimplifyIntrinsicsPass
   void getDependentDialects(mlir::DialectRegistry &registry) const override;
 
 private:
-  /// Helper function to replace a reduction type of call with its
+  /// Helper functions to replace a reduction type of call with its
   /// simplified form. The actual function is generated using a callback
   /// function.
   /// \p call is the call to be replaced
   /// \p kindMap is used to create FIROpBuilder
   /// \p genBodyFunc is the callback that builds the replacement function
-  void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
-                         GenReductionBodyTy genBodyFunc);
+  ///
+  /// simplifyIntOrFloatReduction handles reductions with an integer or
+  /// floating-point result (e.g. Sum, Maxval) when DIM and MASK are absent.
+  void simplifyIntOrFloatReduction(fir::CallOp call,
+                                   const fir::KindMapping &kindMap,
+                                   GenReductionBodyTy genBodyFunc);
+  /// simplifyLogicalReduction handles logical reductions (e.g. Count) when
+  /// DIM is absent.
+  void simplifyLogicalReduction(fir::CallOp call,
+                                const fir::KindMapping &kindMap,
+                                GenReductionBodyTy genBodyFunc);
+  /// simplifyReductionBody is the shared tail of the two helpers above: it
+  /// gets or creates (via getOrCreateFunction) the simplified function named
+  /// \p basename, which the caller has already mangled, using \p builder,
+  /// and replaces \p call with a call to it.
+  void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
+                             GenReductionBodyTy genBodyFunc,
+                             fir::FirOpBuilder &builder,
+                             const mlir::StringRef &basename);
 };
 
 } // namespace
@@ -131,17 +139,18 @@ using InitValGeneratorTy = llvm::function_ref<mlir::Value(
 
 /// Generate the reduction loop into \p funcOp.
 ///
+/// \p elementType is the type of the elements in the input array,
+///    which may be different from the return type.
 /// \p initVal is a function, called to get the initial value for
 ///    the reduction value
 /// \p genBody is called to fill in the actual reduciton operation
 ///    for example add for SUM, MAX for MAXVAL, etc.
 /// \p rank is the rank of the input argument.
-static void genReductionLoop(fir::FirOpBuilder &builder,
+static void genReductionLoop(fir::FirOpBuilder &builder, mlir::Type elementType,
                              mlir::func::FuncOp &funcOp,
                              InitValGeneratorTy initVal,
                              BodyOpGeneratorTy genBody, unsigned rank) {
   auto loc = mlir::UnknownLoc::get(builder.getContext());
-  mlir::Type elementType = funcOp.getResultTypes()[0];
   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
 
   mlir::IndexType idxTy = builder.getIndexType();
@@ -156,7 +165,8 @@ static void genReductionLoop(fir::FirOpBuilder &builder,
   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
-  mlir::Value init = initVal(builder, loc, elementType);
+  mlir::Type resultType = funcOp.getResultTypes()[0];
+  mlir::Value init = initVal(builder, loc, resultType);
 
   llvm::SmallVector<mlir::Value, 15> bounds;
 
@@ -265,7 +275,9 @@ static void genRuntimeSumBody(fir::FirOpBuilder &builder,
     return {};
   };
 
-  genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
+  mlir::Type elementType = funcOp.getResultTypes()[0];
+
+  genReductionLoop(builder, elementType, funcOp, zero, genBodyOp, rank);
 }
 
 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
@@ -293,7 +305,38 @@ static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
     llvm_unreachable("unsupported type");
     return {};
   };
-  genReductionLoop(builder, funcOp, init, genBodyOp, rank);
+
+  mlir::Type elementType = funcOp.getResultTypes()[0];
+
+  genReductionLoop(builder, elementType, funcOp, init, genBodyOp, rank);
+}
+
+/// Generate the body of the simplified RTNAME(Count) function into \p funcOp.
+///
+///   count = 0
+///   loop over all elements
+///     if (arr[i] /= 0) count = count + 1
+///   end loop
+///   return count
+///
+/// The logical input is loaded as 32-bit integers, so this body is only
+/// correct for LOGICAL(4) arrays; the caller must reject other kinds.
+static void genRuntimeCountBody(fir::FirOpBuilder &builder,
+                                mlir::func::FuncOp &funcOp, unsigned rank) {
+  // The running count starts at zero of the type genReductionLoop passes in
+  // (the function's result type).
+  auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
+                 mlir::Type resultType) {
+    return builder.createIntegerConstant(loc, resultType, 0);
+  };
+
+  // count += (elem == 0) ? 0 : 1. The comparison constant is built in the
+  // type of the loaded element and the increments in the type of the count,
+  // so the IR stays well-typed whatever those types are.
+  auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
+                      mlir::Type /*elementType*/, mlir::Value elem,
+                      mlir::Value count) -> mlir::Value {
+    mlir::Type countType = count.getType();
+    mlir::Value elemZero =
+        builder.createIntegerConstant(loc, elem.getType(), 0);
+    mlir::Value cntZero = builder.createIntegerConstant(loc, countType, 0);
+    mlir::Value cntOne = builder.createIntegerConstant(loc, countType, 1);
+
+    auto isFalse = builder.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::eq, elem, elemZero);
+    auto increment =
+        builder.create<mlir::arith::SelectOp>(loc, isFalse, cntZero, cntOne);
+    return builder.create<mlir::arith::AddIOp>(loc, increment, count);
+  };
+
+  // LOGICAL(4) elements are read as i32.
+  mlir::Type elementType = builder.getI32Type();
+
+  genReductionLoop(builder, elementType, funcOp, zero, genBodyOp, rank);
 }
 
 /// Generate function type for the simplified version of RTNAME(DotProduct)
@@ -526,58 +569,99 @@ static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
   } while (true);
 }
 
-void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
-                                               const fir::KindMapping &kindMap,
-                                               GenReductionBodyTy genBodyFunc) {
-  mlir::SymbolRefAttr callee = call.getCalleeAttr();
-  mlir::Operation::operand_range args = call.getArgs();
+/// Simplify an integer or floating-point reduction call (e.g. RTNAME(Sum),
+/// RTNAME(Maxval)) into a call to a generated rank-specialized function.
+/// Only the whole-array form is handled: DIM absent (zero), MASK absent and
+/// a rank greater than zero; otherwise \p call is left unchanged.
+void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
+    fir::CallOp call, const fir::KindMapping &kindMap,
+    GenReductionBodyTy genBodyFunc) {
   // args[1] and args[2] are source filename and line number, ignored.
+  mlir::Operation::operand_range args = call.getArgs();
+
   const mlir::Value &dim = args[3];
   const mlir::Value &mask = args[4];
   // dim is zero when it is absent, which is an implementation
   // detail in the runtime library.
+
   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
   unsigned rank = getDimCount(args[0]);
-  if (dimAndMaskAbsent && rank > 0) {
-    mlir::Location loc = call.getLoc();
-    fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
-    std::string fmfString{getFastMathFlagsString(builder)};
-
-    // Support only floating point and integer results now.
-    mlir::Type resultType = call.getResult(0).getType();
-    if (!resultType.isa<mlir::FloatType>() &&
-        !resultType.isa<mlir::IntegerType>())
-      return;
-
-    auto argType = getArgElementType(args[0]);
-    if (!argType)
-      return;
-    assert(*argType == resultType &&
-           "Argument/result types mismatch in reduction");
-
-    auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
-      return genNoneBoxType(builder, resultType);
-    };
-    auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
-                                               mlir::func::FuncOp &funcOp) {
-      genBodyFunc(builder, funcOp, rank);
-    };
-    // Mangle the function name with the rank value as "x<rank>".
-    std::string funcName =
-        (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
-         mlir::Twine{rank} +
-         // We must mangle the generated function name with FastMathFlags
-         // value.
-         (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
-            .str();
-    mlir::func::FuncOp newFunc =
-        getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
-    auto newCall =
-        builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
-    call->replaceAllUsesWith(newCall.getResults());
-    call->dropAllReferences();
-    call->erase();
-  }
+
+  // Only whole-array reductions without DIM/MASK and with a known, nonzero
+  // rank are simplified.
+  if (!(dimAndMaskAbsent && rank > 0))
+    return;
+
+  mlir::Type resultType = call.getResult(0).getType();
+
+  // Support only floating point and integer results now.
+  if (!resultType.isa<mlir::FloatType>() &&
+      !resultType.isa<mlir::IntegerType>())
+    return;
+
+  auto argType = getArgElementType(args[0]);
+  if (!argType)
+    return;
+  assert(*argType == resultType &&
+         "Argument/result types mismatch in reduction");
+
+  mlir::SymbolRefAttr callee = call.getCalleeAttr();
+
+  fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
+  std::string fmfString{getFastMathFlagsString(builder)};
+  // Mangle the function name with the rank value as "x<rank>".
+  std::string funcName =
+      (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+       mlir::Twine{rank} +
+       // We must mangle the generated function name with FastMathFlags
+       // value.
+       (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
+          .str();
+
+  simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName);
+}
+
+/// Simplify a logical reduction call such as RTNAME(Count) into a call to a
+/// generated rank-specialized function.
+///
+/// \p call is left unchanged unless all of the following hold:
+///   - it has a single i64 result (this also rejects RTNAME(CountDim), which
+///     matches the same name prefix in the dispatcher but returns none);
+///   - DIM (args[3]) is absent, i.e. zero;
+///   - the rank of the input is known and nonzero (the same requirement as
+///     simplifyIntOrFloatReduction);
+///   - the input elements are LOGICAL(4): the generated body loads elements
+///     as i32 and the mangled name does not encode the kind.
+void SimplifyIntrinsicsPass::simplifyLogicalReduction(
+    fir::CallOp call, const fir::KindMapping &kindMap,
+    GenReductionBodyTy genBodyFunc) {
+  // Reject anything that does not return the scalar count directly before
+  // reading positional arguments whose meaning depends on the entry point.
+  if (call.getNumResults() != 1 || !call.getResult(0).getType().isInteger(64))
+    return;
+
+  // args[1] and args[2] are source filename and line number, ignored.
+  mlir::Operation::operand_range args = call.getArgs();
+  const mlir::Value &dim = args[3];
+
+  // dim is zero when it is absent, which is an implementation detail in the
+  // runtime library.
+  if (!isZero(dim))
+    return;
+
+  // A rank of 0 means the rank could not be established; generating an
+  // "x0" function with no loops would be wrong, so keep the runtime call.
+  unsigned rank = getDimCount(args[0]);
+  if (rank == 0)
+    return;
+
+  // genRuntimeCountBody reads each element as i32, which is only valid for
+  // LOGICAL(4); other kinds would read the wrong number of bytes and would
+  // also collide on the same (kind-less) function name.
+  auto argType = getArgElementType(args[0]);
+  if (!argType)
+    return;
+  auto logicalType = argType->dyn_cast<fir::LogicalType>();
+  if (!logicalType || logicalType.getFKind() != 4)
+    return;
+
+  mlir::SymbolRefAttr callee = call.getCalleeAttr();
+
+  fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
+  // Mangle the function name with the rank value as "x<rank>".
+  std::string funcName =
+      (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+       mlir::Twine{rank})
+          .str();
+
+  simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName);
+}
+
+/// Replace \p call with a call to a generated, simplified function.
+///
+/// \p funcName is the complete, already mangled name of the generated
+/// function; name mangling is the callers' responsibility. \p builder is the
+/// builder the caller obtained from getSimplificationBuilder; it is used to
+/// get or create the function (filled in by \p genBodyFunc for the rank of
+/// the call's first argument) and to create the new call. The new call takes
+/// only args[0]; its results replace all uses of \p call, which is erased.
+/// \p kindMap is not used here.
+void SimplifyIntrinsicsPass::simplifyReductionBody(
+    fir::CallOp call, const fir::KindMapping &kindMap,
+    GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
+    const mlir::StringRef &funcName) {
+  mlir::Operation::operand_range args = call.getArgs();
+  mlir::Value array = args[0];
+  mlir::Type resultType = call.getResult(0).getType();
+  unsigned rank = getDimCount(array);
+
+  // Signature of the generated function, built from the original result type.
+  auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
+    return genNoneBoxType(builder, resultType);
+  };
+  auto bodyGenerator = [rank, &genBodyFunc](fir::FirOpBuilder &builder,
+                                            mlir::func::FuncOp &funcOp) {
+    genBodyFunc(builder, funcOp, rank);
+  };
+  mlir::func::FuncOp newFunc =
+      getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
+
+  auto newCall = builder.create<fir::CallOp>(call.getLoc(), newFunc,
+                                             mlir::ValueRange{array});
+  call->replaceAllUsesWith(newCall.getResults());
+  call->dropAllReferences();
+  call->erase();
 }
 
 void SimplifyIntrinsicsPass::runOnOperation() {
@@ -598,7 +682,7 @@ void SimplifyIntrinsicsPass::runOnOperation() {
         //                int dim, const Descriptor *mask)
         //
         if (funcName.startswith(RTNAME_STRING(Sum))) {
-          simplifyReduction(call, kindMap, genRuntimeSumBody);
+          simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
           return;
         }
         if (funcName.startswith(RTNAME_STRING(DotProduct))) {
@@ -669,7 +753,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
           return;
         }
         if (funcName.startswith(RTNAME_STRING(Maxval))) {
-          simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
+          simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
+          return;
+        }
+        if (funcName.startswith(RTNAME_STRING(Count))) {
+          simplifyLogicalReduction(call, kindMap, genRuntimeCountBody);
           return;
         }
       }

diff  --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index dbd23520ef95a..e4a2eaeb486da 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -1098,3 +1098,119 @@ fir.global linkonce @_QQcl.2E2F6973756D5F352E66393000 constant : !fir.char<1,13>
 // CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f64
 // CHECK-LABEL: func.func private @_FortranASumReal8x1_fast_simplified
 // CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<fast> : f64
+
+// -----
+// Ensure COUNT on a rank-1 LOGICAL(4) array without DIM is simplified.
+
+func.func @_QMtestPcount_generate_mask(%arg0: !fir.ref<f32> {fir.bindc_name = "a"}) -> i32 {
+  %0 = fir.alloca i32 {bindc_name = "count_generate_mask", uniq_name = "_QMtestFcount_generate_maskEcount_generate_mask"}
+  %c10 = arith.constant 10 : index
+  %1 = fir.alloca !fir.array<10x!fir.logical<4>> {bindc_name = "mask", uniq_name = "_QMtestFcount_generate_maskEmask"}
+  %2 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %3 = fir.embox %1(%2) : (!fir.ref<!fir.array<10x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<10x!fir.logical<4>>>
+  %c0 = arith.constant 0 : index
+  %4 = fir.address_of(@_QQcl.2E2F746573746661696C2E66393000) : !fir.ref<!fir.char<1,15>>
+  %c10_i32 = arith.constant 10 : i32
+  %5 = fir.convert %3 : (!fir.box<!fir.array<10x!fir.logical<4>>>) -> !fir.box<none>
+  %6 = fir.convert %4 : (!fir.ref<!fir.char<1,15>>) -> !fir.ref<i8>
+  // The last runtime argument is 0 (%7), i.e. no DIM was given.
+  %7 = fir.convert %c0 : (index) -> i32
+  %8 = fir.call @_FortranACount(%5, %6, %c10_i32, %7) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
+  %9 = fir.convert %8 : (i64) -> i32
+  fir.store %9 to %0 : !fir.ref<i32>
+  %10 = fir.load %0 : !fir.ref<i32>
+  return %10 : i32
+}
+func.func private @_FortranACount(!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F746573746661696C2E66393000 constant : !fir.char<1,15> {
+  %0 = fir.string_lit "./test.f90\00"(15) : !fir.char<1,15>
+  fir.has_value %0 : !fir.char<1,15>
+}
+
+// CHECK-LABEL:   func.func @_QMtestPcount_generate_mask(
+// CHECK-SAME:                                           %[[A:.*]]: !fir.ref<f32> {fir.bindc_name = "a"}) -> i32 {
+// CHECK:           %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+// CHECK:           %[[A_BOX_LOGICAL:.*]] = fir.embox %{{.*}}(%[[SHAPE]]) : (!fir.ref<!fir.array<10x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<10x!fir.logical<4>>>
+// CHECK:           %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_LOGICAL]] : (!fir.box<!fir.array<10x!fir.logical<4>>>) -> !fir.box<none>
+// CHECK-NOT:       fir.call @_FortranACount({{.*}})
+// CHECK:           %[[RES:.*]] = fir.call @_FortranACountx1_simplified(%[[A_BOX_NONE]]) fastmath<contract> : (!fir.box<none>) -> i64
+// CHECK-NOT:       fir.call @_FortranACount({{.*}})
+// CHECK:           return %{{.*}} : i32
+// CHECK:         }
+// CHECK:         func.func private @_FortranACount(!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64 attributes {fir.runtime}
+
+// CHECK-LABEL:   func.func private @_FortranACountx1_simplified(
+// CHECK-SAME:                                                            %[[ARR:.*]]: !fir.box<none>) -> i64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[C_INDEX0:.*]] = arith.constant 0 : index
+// CHECK:           %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
+// CHECK:           %[[IZERO:.*]] = arith.constant 0 : i64
+// CHECK:           %[[C_INDEX1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK:           %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[C_INDEX1]] : index
+// CHECK:           %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[C_INDEX0]] to %[[EXTENT]] step %[[C_INDEX1]] iter_args(%[[COUNT:.*]] = %[[IZERO]]) -> (i64) {
+// CHECK:             %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+// CHECK:             %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
+// CHECK:             %[[I32_0:.*]] = arith.constant 0 : i32
+// CHECK:             %[[I64_0:.*]] = arith.constant 0 : i64
+// CHECK:             %[[I64_1:.*]] = arith.constant 1 : i64
+// CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[ITEM_VAL]], %[[I32_0]] : i32
+// CHECK:             %[[SELECT:.*]] = arith.select %[[CMP]], %[[I64_0]], %[[I64_1]] : i64
+// CHECK:             %[[NEW_COUNT:.*]] = arith.addi %[[SELECT]], %[[COUNT]] : i64
+// CHECK:             fir.result %[[NEW_COUNT]] : i64
+// CHECK:           }
+// CHECK:           return %[[RES]] : i64
+// CHECK:         }
+
+// -----
+// Ensure COUNT with a DIM argument (lowered to _FortranACountDim) is not simplified.
+
+func.func @_QMtestPcount_generate_mask(%arg0: !fir.ref<!fir.array<10x10x!fir.logical<4>>> {fir.bindc_name = "mask"}) -> !fir.array<10xi32> {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
+  %c10 = arith.constant 10 : index
+  %c10_0 = arith.constant 10 : index
+  %c10_1 = arith.constant 10 : index
+  %1 = fir.alloca !fir.array<10xi32> {bindc_name = "res", uniq_name = "_QMtestFcount_generate_maskEres"}
+  %2 = fir.shape %c10_1 : (index) -> !fir.shape<1>
+  %3 = fir.array_load %1(%2) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.array<10xi32>
+  %c2_i32 = arith.constant 2 : i32
+  %4 = fir.shape %c10, %c10_0 : (index, index) -> !fir.shape<2>
+  %5 = fir.embox %arg0(%4) : (!fir.ref<!fir.array<10x10x!fir.logical<4>>>, !fir.shape<2>) -> !fir.box<!fir.array<10x10x!fir.logical<4>>>
+  %c4 = arith.constant 4 : index
+  %6 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+  %c0 = arith.constant 0 : index
+  %7 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %8 = fir.embox %6(%7) : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.store %8 to %0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %9 = fir.address_of(@_QQcl.2E2F746573746661696C2E66393000) : !fir.ref<!fir.char<1,15>>
+  %c11_i32 = arith.constant 11 : i32
+  %10 = fir.convert %0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+  %11 = fir.convert %5 : (!fir.box<!fir.array<10x10x!fir.logical<4>>>) -> !fir.box<none>
+  %12 = fir.convert %c4 : (index) -> i32
+  %13 = fir.convert %9 : (!fir.ref<!fir.char<1,15>>) -> !fir.ref<i8>
+  %14 = fir.call @_FortranACountDim(%10, %11, %c2_i32, %12, %13, %c11_i32) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
+  %15 = fir.load %0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %c0_2 = arith.constant 0 : index
+  %16:3 = fir.box_dims %15, %c0_2 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
+  %17 = fir.box_addr %15 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+  %18 = fir.shape_shift %16#0, %16#1 : (index, index) -> !fir.shapeshift<1>
+  %19 = fir.array_load %17(%18) : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.array<?xi32>
+  %c1 = arith.constant 1 : index
+  %c0_3 = arith.constant 0 : index
+  %20 = arith.subi %c10_1, %c1 : index
+  %21 = fir.do_loop %arg1 = %c0_3 to %20 step %c1 unordered iter_args(%arg2 = %3) -> (!fir.array<10xi32>) {
+    %23 = fir.array_fetch %19, %arg1 : (!fir.array<?xi32>, index) -> i32
+    %24 = fir.array_update %arg2, %23, %arg1 : (!fir.array<10xi32>, i32, index) -> !fir.array<10xi32>
+    fir.result %24 : !fir.array<10xi32>
+  }
+  fir.array_merge_store %3, %21 to %1 : !fir.array<10xi32>, !fir.array<10xi32>, !fir.ref<!fir.array<10xi32>>
+  fir.freemem %17 : !fir.heap<!fir.array<?xi32>>
+  %22 = fir.load %1 : !fir.ref<!fir.array<10xi32>>
+  return %22 : !fir.array<10xi32>
+}
+func.func private @_FortranACountDim(!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none attributes {fir.runtime}
+
+// CHECK-LABEL:   func.func @_QMtestPcount_generate_mask(
+// CHECK-SAME:                                           %[[A:.*]]: !fir.ref<!fir.array<10x10x!fir.logical<4>>> {fir.bindc_name = "mask"}) -> !fir.array<10xi32> {
+// CHECK-NOT:       fir.call @_FortranACount{{.*}}_simplified
+// CHECK:           %[[RES:.*]] = fir.call @_FortranACountDim({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
+// CHECK-NOT:       fir.call @_FortranACount{{.*}}_simplified


        


More information about the flang-commits mailing list