[flang-commits] [flang] 60f02aa - [flang][hlfir] Fixed KindMapping for HLFIR intrinsics lowering.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Mon Jul 24 10:12:49 PDT 2023


Author: Slava Zakharin
Date: 2023-07-24T10:12:39-07:00
New Revision: 60f02aa7f78c9bd7ffaf816a48700d0adc814d2b

URL: https://github.com/llvm/llvm-project/commit/60f02aa7f78c9bd7ffaf816a48700d0adc814d2b
DIFF: https://github.com/llvm/llvm-project/commit/60f02aa7f78c9bd7ffaf816a48700d0adc814d2b.diff

LOG: [flang][hlfir] Fixed KindMapping for HLFIR intrinsics lowering.

hlfir.count lowering was using incorrect default integer kind
by ignoring the kind specified in the ModuleOp.

Reviewed By: tblah

Differential Revision: https://reviews.llvm.org/D156017

Added: 
    flang/test/HLFIR/count-lowering-default-int-kinds.fir

Modified: 
    flang/include/flang/Optimizer/Builder/FIRBuilder.h
    flang/include/flang/Optimizer/Dialect/Support/FIRContext.h
    flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
    flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
    flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
    flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
    flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
    flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
    flang/test/HLFIR/count-lowering.fir

Removed: 
    


################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 451d0f12945527..9aa6fa9659d062 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -63,6 +63,8 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
       setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
     }
   }
+  FirOpBuilder(mlir::OpBuilder &builder, mlir::Operation *op)
+      : FirOpBuilder(builder, fir::getKindMapping(op), op) {}
 
   // The listener self-reference has to be updated in case of copy-construction.
   FirOpBuilder(const FirOpBuilder &other)

diff --git a/flang/include/flang/Optimizer/Dialect/Support/FIRContext.h b/flang/include/flang/Optimizer/Dialect/Support/FIRContext.h
index 0888a846456931..ad91da6a4432da 100644
--- a/flang/include/flang/Optimizer/Dialect/Support/FIRContext.h
+++ b/flang/include/flang/Optimizer/Dialect/Support/FIRContext.h
@@ -22,6 +22,7 @@
 
 namespace mlir {
 class ModuleOp;
+class Operation;
 } // namespace mlir
 
 namespace fir {
@@ -43,6 +44,12 @@ void setKindMapping(mlir::ModuleOp mod, KindMapping &kindMap);
 /// default.
 KindMapping getKindMapping(mlir::ModuleOp mod);
 
+/// Get the KindMapping instance that is in effect for the specified
+/// operation. The KindMapping is taken from the operation itself,
+/// if the operation is a ModuleOp, or from its parent ModuleOp.
+/// If a ModuleOp cannot be reached, the function returns default KindMapping.
+KindMapping getKindMapping(mlir::Operation *op);
+
 /// Helper for determining the target from the host, etc. Tools may use this
 /// function to provide a consistent interpretation of the `--target=<string>`
 /// command-line option.

diff --git a/flang/lib/Optimizer/Dialect/Support/FIRContext.cpp b/flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
index a2b77ebcf52ae1..1b2d1c88c86b5b 100644
--- a/flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
+++ b/flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
@@ -51,6 +51,15 @@ fir::KindMapping fir::getKindMapping(mlir::ModuleOp mod) {
   return fir::KindMapping(ctx);
 }
 
+fir::KindMapping fir::getKindMapping(mlir::Operation *op) {
+  auto moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);
+  if (moduleOp)
+    return getKindMapping(moduleOp);
+
+  moduleOp = op->getParentOfType<mlir::ModuleOp>();
+  return getKindMapping(moduleOp);
+}
+
 std::string fir::determineTargetTriple(llvm::StringRef triple) {
   // Treat "" or "default" as stand-ins for the default machine.
   if (triple.empty() || triple == "default")

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index bb110d4a80a2cf..09db0b19c78018 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -253,8 +253,7 @@ struct ApplyOpConversion : public mlir::OpConversionPattern<hlfir::ApplyOp> {
     if (fir::isa_trivial(apply.getType())) {
       result = rewriter.create<fir::LoadOp>(loc, result);
     } else {
-      auto module = apply->getParentOfType<mlir::ModuleOp>();
-      fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+      fir::FirOpBuilder builder(rewriter, apply.getOperation());
       result =
           packageBufferizedExpr(loc, builder, hlfir::Entity{result}, false);
     }
@@ -288,8 +287,7 @@ struct ConcatOpConversion : public mlir::OpConversionPattern<hlfir::ConcatOp> {
   matchAndRewrite(hlfir::ConcatOp concat, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = concat->getLoc();
-    auto module = concat->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, concat.getOperation());
     assert(adaptor.getStrings().size() >= 2 &&
            "must have at least two strings operands");
     if (adaptor.getStrings().size() > 2)
@@ -328,8 +326,7 @@ struct SetLengthOpConversion
   matchAndRewrite(hlfir::SetLengthOp setLength, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = setLength->getLoc();
-    auto module = setLength->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, setLength.getOperation());
     // Create a temp with the new length.
     hlfir::Entity string = getBufferizedExprStorage(adaptor.getString());
     auto charType = hlfir::getFortranElementType(setLength.getType());
@@ -362,8 +359,7 @@ struct GetLengthOpConversion
   matchAndRewrite(hlfir::GetLengthOp getLength, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = getLength->getLoc();
-    auto module = getLength->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, getLength.getOperation());
     hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
     mlir::Value length = hlfir::genCharLength(loc, builder, bufferizedExpr);
     if (!length)
@@ -436,8 +432,7 @@ struct AssociateOpConversion
   matchAndRewrite(hlfir::AssociateOp associate, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = associate->getLoc();
-    auto module = associate->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, associate.getOperation());
     mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource());
     const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType());
 
@@ -577,8 +572,7 @@ struct EndAssociateOpConversion
   matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = endAssociate->getLoc();
-    auto module = endAssociate->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, endAssociate.getOperation());
     genFreeIfMustFree(loc, builder, adaptor.getVar(), adaptor.getMustFree());
     rewriter.eraseOp(endAssociate);
     return mlir::success();
@@ -597,8 +591,7 @@ struct DestroyOpConversion
     mlir::Location loc = destroy->getLoc();
     hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr());
     if (!fir::isa_trivial(bufferizedExpr.getType())) {
-      auto module = destroy->getParentOfType<mlir::ModuleOp>();
-      fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+      fir::FirOpBuilder builder(rewriter, destroy.getOperation());
       mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getExpr());
       mlir::Value firBase = bufferizedExpr.getFirBase();
       genFreeIfMustFree(loc, builder, firBase, mustFree);
@@ -617,8 +610,7 @@ struct NoReassocOpConversion
   matchAndRewrite(hlfir::NoReassocOp noreassoc, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = noreassoc->getLoc();
-    auto module = noreassoc->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, noreassoc.getOperation());
     mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getVal());
     mlir::Value result =
         builder.create<hlfir::NoReassocOp>(loc, bufferizedExpr);
@@ -677,8 +669,7 @@ struct ElementalOpConversion
   matchAndRewrite(hlfir::ElementalOp elemental, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = elemental->getLoc();
-    auto module = elemental->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, elemental.getOperation());
     // The body of the elemental op may contain operation that will require
     // to be translated. Notify the rewriter about the cloned operations.
     HLFIRListener listener{builder, rewriter};
@@ -743,11 +734,10 @@ struct CharExtremumOpConversion
   matchAndRewrite(hlfir::CharExtremumOp char_extremum, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = char_extremum->getLoc();
-    auto module = char_extremum->getParentOfType<mlir::ModuleOp>();
     auto predicate = char_extremum.getPredicate();
     bool predIsMin =
         predicate == hlfir::CharExtremumPredicate::min ? true : false;
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, char_extremum.getOperation());
     assert(adaptor.getStrings().size() >= 2 &&
            "must have at least two strings operands");
     auto numOperands = adaptor.getStrings().size();

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
index aa893e4805451a..835ff5de02ebbb 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
@@ -240,8 +240,7 @@ class CopyInOpConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
   matchAndRewrite(hlfir::CopyInOp copyInOp,
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = copyInOp.getLoc();
-    auto module = copyInOp->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, copyInOp.getOperation());
     CopyInResult result = copyInOp.getVarIsPresent()
                               ? genOptionalCopyIn(loc, builder, copyInOp)
                               : genNonOptionalCopyIn(loc, builder, copyInOp);
@@ -259,8 +258,7 @@ class CopyOutOpConversion : public mlir::OpRewritePattern<hlfir::CopyOutOp> {
   matchAndRewrite(hlfir::CopyOutOp copyOutOp,
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = copyOutOp.getLoc();
-    auto module = copyOutOp->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, copyOutOp.getOperation());
 
     builder.genIfThen(loc, copyOutOp.getWasCopied())
         .genThen([&]() {
@@ -323,8 +321,7 @@ class DeclareOpConversion : public mlir::OpRewritePattern<hlfir::DeclareOp> {
     mlir::Value hlfirBase;
     mlir::Type hlfirBaseType = declareOp.getBase().getType();
     if (hlfirBaseType.isa<fir::BaseBoxType>()) {
-      auto module = declareOp->getParentOfType<mlir::ModuleOp>();
-      fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+      fir::FirOpBuilder builder(rewriter, declareOp.getOperation());
       // Helper to generate the hlfir fir.box with the local lower bounds and
       // type parameters.
       auto genHlfirBox = [&]() -> mlir::Value {
@@ -423,8 +420,7 @@ class DesignateOpConversion
   matchAndRewrite(hlfir::DesignateOp designate,
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = designate.getLoc();
-    auto module = designate->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    fir::FirOpBuilder builder(rewriter, designate.getOperation());
 
     hlfir::Entity baseEntity(designate.getMemref());
 

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
index fe4201b79cffe2..bbc989f9e046e7 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
@@ -14,7 +14,6 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
-#include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -91,8 +90,7 @@ class InlineElementalConversion
     assert(elemental.getRegion().hasOneBlock() &&
            "expect elemental region to have one block");
 
-    fir::FirOpBuilder builder{rewriter,
-                              fir::KindMapping{rewriter.getContext()}};
+    fir::FirOpBuilder builder{rewriter, elemental.getOperation()};
     builder.setInsertionPointAfter(apply);
     hlfir::YieldElementOp yield = hlfir::inlineElementalOp(
         elemental.getLoc(), builder, elemental, apply.getIndices());

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 0e40e2f8c95297..f30d79ebaa9b40 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -69,8 +69,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
                  mlir::PatternRewriter &rewriter,
                  const fir::IntrinsicArgumentLoweringRules *argLowering) const {
     mlir::Location loc = op->getLoc();
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, op};
+    fir::FirOpBuilder builder{rewriter, op};
 
     llvm::SmallVector<fir::ExtendedValue, 3> ret;
     llvm::SmallVector<std::function<void()>, 2> cleanupFns;
@@ -229,8 +228,7 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
       return mlir::failure();
     }
 
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, operation};
+    fir::FirOpBuilder builder{rewriter, operation.getOperation()};
     const mlir::Location &loc = operation->getLoc();
 
     mlir::Type i32 = builder.getI32Type();
@@ -271,8 +269,7 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
   mlir::LogicalResult
   matchAndRewrite(hlfir::CountOp count,
                   mlir::PatternRewriter &rewriter) const override {
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, count};
+    fir::FirOpBuilder builder{rewriter, count.getOperation()};
     const mlir::Location &loc = count->getLoc();
 
     mlir::Type i32 = builder.getI32Type();
@@ -304,8 +301,7 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
   mlir::LogicalResult
   matchAndRewrite(hlfir::MatmulOp matmul,
                   mlir::PatternRewriter &rewriter) const override {
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, matmul};
+    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
     const mlir::Location &loc = matmul->getLoc();
 
     mlir::Value lhs = matmul.getLhs();
@@ -336,8 +332,7 @@ struct DotProductOpConversion
   mlir::LogicalResult
   matchAndRewrite(hlfir::DotProductOp dotProduct,
                   mlir::PatternRewriter &rewriter) const override {
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, dotProduct};
+    fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
     const mlir::Location &loc = dotProduct->getLoc();
 
     mlir::Value lhs = dotProduct.getLhs();
@@ -368,8 +363,7 @@ class TransposeOpConversion
   mlir::LogicalResult
   matchAndRewrite(hlfir::TransposeOp transpose,
                   mlir::PatternRewriter &rewriter) const override {
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, transpose};
+    fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
     const mlir::Location &loc = transpose->getLoc();
 
     mlir::Value arg = transpose.getArray();
@@ -399,8 +393,7 @@ struct MatmulTransposeOpConversion
   mlir::LogicalResult
   matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
                   mlir::PatternRewriter &rewriter) const override {
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping, multranspose};
+    fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
     const mlir::Location &loc = multranspose->getLoc();
 
     mlir::Value lhs = multranspose.getLhs();

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 141c41d1a63ce1..6206deee411c34 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -13,7 +13,6 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
-#include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
@@ -40,8 +39,7 @@ class TransposeAsElementalConversion
   matchAndRewrite(hlfir::TransposeOp transpose,
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = transpose.getLoc();
-    fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
     hlfir::ExprType expr = transpose.getType();
     mlir::Type elementType = expr.getElementType();
     hlfir::Entity array = hlfir::Entity{transpose.getArray()};

diff --git a/flang/test/HLFIR/count-lowering-default-int-kinds.fir b/flang/test/HLFIR/count-lowering-default-int-kinds.fir
new file mode 100644
index 00000000000000..b34639a0aa0af0
--- /dev/null
+++ b/flang/test/HLFIR/count-lowering-default-int-kinds.fir
@@ -0,0 +1,47 @@
+// Test hlfir.count operation lowering with different default integer kinds.
+// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s
+
+module attributes {fir.defaultkind = "a1c4d8i8l4r4", fir.kindmap = ""} {
+  func.func @test_i8(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+    %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi64>
+    return
+  }
+}
+// CHECK-LABEL: func.func @test_i8
+// CHECK: %[[KIND:.*]] = arith.constant 8 : index
+// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
+// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
+
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = ""} {
+  func.func @test_i4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+    %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi32>
+    return
+  }
+}
+// CHECK-LABEL: func.func @test_i4
+// CHECK: %[[KIND:.*]] = arith.constant 4 : index
+// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
+// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
+
+module attributes {fir.defaultkind = "a1c4d8i2l4r4", fir.kindmap = ""} {
+  func.func @test_i2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+    %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi16>
+    return
+  }
+}
+// CHECK-LABEL: func.func @test_i2
+// CHECK: %[[KIND:.*]] = arith.constant 2 : index
+// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
+// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
+
+module attributes {fir.defaultkind = "a1c4d8i1l4r4", fir.kindmap = ""} {
+  func.func @test_i1(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+    %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi8>
+    return
+  }
+}
+// CHECK-LABEL: func.func @test_i1
+// CHECK: arith.constant 1 : index
+// CHECK: %[[KIND:.*]] = arith.constant 1 : index
+// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
+// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none

diff --git a/flang/test/HLFIR/count-lowering.fir b/flang/test/HLFIR/count-lowering.fir
index 0d9cc34a316eb8..b391a651a6ae7f 100644
--- a/flang/test/HLFIR/count-lowering.fir
+++ b/flang/test/HLFIR/count-lowering.fir
@@ -39,6 +39,7 @@ func.func @_QPcount2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc
 // CHECK-DAG:     %[[RES:.*]]:2 = hlfir.declare %[[ARG1]]
 
 // CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
+// CHECK-DAG:     %[[KIND:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?xi32>>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
@@ -48,8 +49,9 @@ func.func @_QPcount2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc
 // CHECK-DAG:     %[[DIM:.*]] = fir.load %[[DIM_VAR]]#0 : !fir.ref<i32>
 // CHECK-DAG:     %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]]
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1
+// CHECK-DAG:     %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
 
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[KIND_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
 // CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
 // CHECK:         %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
@@ -80,6 +82,7 @@ func.func @_QPcount3(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "s"})
 // CHECK-LABEL:  func.func @_QPcount3(
 // CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<2xi32>>
 // CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
+// CHECK-DAG:     %[[KIND:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?xi32>>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
@@ -95,7 +98,9 @@ func.func @_QPcount3(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "s"})
 
 // CHECK-DAG:     %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]]
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box<!fir.array<2x2x!fir.logical<4>>>) -> !fir.box<none>
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-DAG:     %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32
+
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[KIND_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
 // CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
 // CHECK:         %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]


        


More information about the flang-commits mailing list