[flang-commits] [flang] d2d2130 - [flang][hlfir][NFC] refactor transformational intrinsic lowering

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Tue Jul 4 02:35:28 PDT 2023


Author: Tom Eccles
Date: 2023-07-04T09:34:43Z
New Revision: d2d213018df6d2d6e8b9c4f6d88d52ee5a94fb49

URL: https://github.com/llvm/llvm-project/commit/d2d213018df6d2d6e8b9c4f6d88d52ee5a94fb49
DIFF: https://github.com/llvm/llvm-project/commit/d2d213018df6d2d6e8b9c4f6d88d52ee5a94fb49.diff

LOG: [flang][hlfir][NFC] refactor transformational intrinsic lowering

The old code had grown unwieldy and had become difficult to read and
modify. I've rewritten it and moved it into its own translation unit.

I moved PreparedActualArgument to the header file for the
transformational intrinsic lowering. Logically, it belongs in
ConvertCall.h, but putting it there would create a circular dependency
between HlfirIntrinsics and ConvertCall.

Differential Revision: https://reviews.llvm.org/D154235

Added: 
    flang/include/flang/Lower/HlfirIntrinsics.h
    flang/lib/Lower/HlfirIntrinsics.cpp

Modified: 
    flang/lib/Lower/CMakeLists.txt
    flang/lib/Lower/ConvertCall.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/HlfirIntrinsics.h b/flang/include/flang/Lower/HlfirIntrinsics.h
new file mode 100644
index 00000000000000..df1f1ac9a7cf5a
--- /dev/null
+++ b/flang/include/flang/Lower/HlfirIntrinsics.h
@@ -0,0 +1,90 @@
+//===-- HlfirIntrinsics.h -- lowering to HLFIR intrinsic ops ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+///
+/// Declares the lowering of intrinsic calls to HLFIR intrinsic operations
+/// (e.g. hlfir::SumOp, hlfir::MatmulOp) and the PreparedActualArgument type.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_HLFIRINTRINSICS_H
+#define FORTRAN_LOWER_HLFIRINTRINSICS_H
+
+#include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cassert>
+#include <optional>
+#include <string>
+
+namespace mlir {
+class Location;
+class Type;
+class Value;
+class ValueRange;
+} // namespace mlir
+
+namespace fir {
+class FirOpBuilder;
+struct IntrinsicArgumentLoweringRules;
+} // namespace fir
+
+namespace Fortran::lower {
+
+/// Holds the initial lowered value of an actual argument (lowered without
+/// regard to the callee interface) and, when the actual may be absent at
+/// runtime and the dummy is optional, a value telling whether it is present.
+struct PreparedActualArgument {
+  // The range given to setElementalIndices() must outlive getActual() calls.
+  PreparedActualArgument(hlfir::Entity actual,
+                         std::optional<mlir::Value> isPresent)
+      : actual{actual}, isPresent{isPresent} {}
+  void setElementalIndices(mlir::ValueRange &indices) {
+    oneBasedElementalIndices = &indices;
+  }
+  hlfir::Entity getActual(mlir::Location loc,
+                          fir::FirOpBuilder &builder) const {
+    if (oneBasedElementalIndices)
+      return hlfir::getElementAt(loc, builder, actual,
+                                 *oneBasedElementalIndices);
+    return actual;
+  }
+  hlfir::Entity getOriginalActual() const { return actual; }
+  void setOriginalActual(hlfir::Entity newActual) { actual = newActual; }
+  bool handleDynamicOptional() const { return isPresent.has_value(); }
+  mlir::Value getIsPresent() const {
+    assert(handleDynamicOptional() && "not a dynamic optional");
+    return *isPresent;
+  }
+
+  void resetOptionalAspect() { isPresent = std::nullopt; }
+
+private:
+  hlfir::Entity actual;
+  mlir::ValueRange *oneBasedElementalIndices{nullptr}; // not owned
+  // When the actual may be dynamically optional, "isPresent"
+  // holds a boolean value indicating the presence of the
+  // actual argument at runtime.
+  std::optional<mlir::Value> isPresent;
+};
+
+/// Vector of pre-lowered actual arguments. An element is std::nullopt when the
+/// actual is statically absent (i.e. it was not syntactically provided).
+using PreparedActualArguments =
+    llvm::SmallVector<std::optional<PreparedActualArgument>>;
+/// Lower \p name to an HLFIR op; std::nullopt means the call was not handled.
+std::optional<hlfir::EntityWithAttributes> lowerHlfirIntrinsic(
+    fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType);
+
+} // namespace Fortran::lower
+#endif // FORTRAN_LOWER_HLFIRINTRINSICS_H

diff  --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt
index a0aa4844d8c2b5..b13d415e02f1d9 100644
--- a/flang/lib/Lower/CMakeLists.txt
+++ b/flang/lib/Lower/CMakeLists.txt
@@ -6,6 +6,7 @@ add_flang_library(FortranLower
   Bridge.cpp
   CallInterface.cpp
   Coarray.cpp
+  ComponentPath.cpp
   ConvertArrayConstructor.cpp
   ConvertCall.cpp
   ConvertConstant.cpp
@@ -14,9 +15,9 @@ add_flang_library(FortranLower
   ConvertProcedureDesignator.cpp
   ConvertType.cpp
   ConvertVariable.cpp
-  ComponentPath.cpp
   CustomIntrinsicCall.cpp
   DumpEvaluateExpr.cpp
+  HlfirIntrinsics.cpp
   HostAssociations.cpp
   IO.cpp
   IterationSpace.cpp

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 3ba7d3cfb6ea6e..9bbb9468831ad9 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -14,6 +14,7 @@
 #include "flang/Lower/ConvertExprToHLFIR.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/CustomIntrinsicCall.h"
+#include "flang/Lower/HlfirIntrinsics.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -615,50 +616,8 @@ struct CallContext {
   std::optional<mlir::Type> resultType;
   mlir::Location loc;
 };
-
-/// This structure holds the initial lowered value of an actual argument that
-/// was lowered regardless of the interface, and it holds whether or not it
-/// may be absent at runtime and the dummy is optional.
-struct PreparedActualArgument {
-
-  PreparedActualArgument(hlfir::Entity actual,
-                         std::optional<mlir::Value> isPresent)
-      : actual{actual}, isPresent{isPresent} {}
-  void setElementalIndices(mlir::ValueRange &indices) {
-    oneBasedElementalIndices = &indices;
-  }
-  hlfir::Entity getActual(mlir::Location loc,
-                          fir::FirOpBuilder &builder) const {
-    if (oneBasedElementalIndices)
-      return hlfir::getElementAt(loc, builder, actual,
-                                 *oneBasedElementalIndices);
-    return actual;
-  }
-  hlfir::Entity getOriginalActual() const { return actual; }
-  void setOriginalActual(hlfir::Entity newActual) { actual = newActual; }
-  bool handleDynamicOptional() const { return isPresent.has_value(); }
-  mlir::Value getIsPresent() const {
-    assert(handleDynamicOptional() && "not a dynamic optional");
-    return *isPresent;
-  }
-
-  void resetOptionalAspect() { isPresent = std::nullopt; }
-
-private:
-  hlfir::Entity actual;
-  mlir::ValueRange *oneBasedElementalIndices{nullptr};
-  // When the actual may be dynamically optional, "isPresent"
-  // holds a boolean value indicating the presence of the
-  // actual argument at runtime.
-  std::optional<mlir::Value> isPresent;
-};
 } // namespace
 
-/// Vector of pre-lowered actual arguments. nullopt if the actual is
-/// "statically" absent (if it was not syntactically  provided).
-using PreparedActualArguments =
-    llvm::SmallVector<std::optional<PreparedActualArgument>>;
-
 // Helper to transform a fir::ExtendedValue to an hlfir::EntityWithAttributes.
 static hlfir::EntityWithAttributes
 extendedValueToHlfirEntity(mlir::Location loc, fir::FirOpBuilder &builder,
@@ -860,7 +819,8 @@ static hlfir::Entity fixProcedureDummyMismatch(mlir::Location loc,
 /// The optional aspects must be handled by this function user.
 static PreparedDummyArgument preparePresentUserCallActualArgument(
     mlir::Location loc, fir::FirOpBuilder &builder,
-    const PreparedActualArgument &preparedActual, mlir::Type dummyType,
+    const Fortran::lower::PreparedActualArgument &preparedActual,
+    mlir::Type dummyType,
     const Fortran::lower::CallerInterface::PassedEntity &arg,
     const Fortran::lower::SomeExpr &expr,
     Fortran::lower::AbstractConverter &converter) {
@@ -1023,7 +983,8 @@ static PreparedDummyArgument preparePresentUserCallActualArgument(
 /// of any optional aspect.
 static PreparedDummyArgument prepareUserCallActualArgument(
     mlir::Location loc, fir::FirOpBuilder &builder,
-    const PreparedActualArgument &preparedActual, mlir::Type dummyType,
+    const Fortran::lower::PreparedActualArgument &preparedActual,
+    mlir::Type dummyType,
     const Fortran::lower::CallerInterface::PassedEntity &arg,
     const Fortran::lower::SomeExpr &expr,
     Fortran::lower::AbstractConverter &converter) {
@@ -1094,7 +1055,7 @@ static PreparedDummyArgument prepareUserCallActualArgument(
 /// the array argument elements value and will return the corresponding
 /// scalar result value.
 static std::optional<hlfir::EntityWithAttributes>
-genUserCall(PreparedActualArguments &loweredActuals,
+genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
             Fortran::lower::CallerInterface &caller,
             mlir::FunctionType callSiteType, CallContext &callContext) {
   using PassBy = Fortran::lower::CallerInterface::PassEntityBy;
@@ -1221,7 +1182,7 @@ genUserCall(PreparedActualArguments &loweredActuals,
 /// Lower calls to intrinsic procedures with actual arguments that have been
 /// pre-lowered but have not yet been prepared according to the interface.
 static std::optional<hlfir::EntityWithAttributes>
-genIntrinsicRefCore(PreparedActualArguments &loweredActuals,
+genIntrinsicRefCore(Fortran::lower::PreparedActualArguments &loweredActuals,
                     const Fortran::evaluate::SpecificIntrinsic *intrinsic,
                     const fir::IntrinsicArgumentLoweringRules *argLowering,
                     CallContext &callContext) {
@@ -1343,199 +1304,29 @@ genIntrinsicRefCore(PreparedActualArguments &loweredActuals,
 
 /// Lower calls to intrinsic procedures with actual arguments that have been
 /// pre-lowered but have not yet been prepared according to the interface.
-static std::optional<hlfir::EntityWithAttributes>
-genHLFIRIntrinsicRefCore(PreparedActualArguments &loweredActuals,
-                         const Fortran::evaluate::SpecificIntrinsic *intrinsic,
-                         const fir::IntrinsicArgumentLoweringRules *argLowering,
-                         CallContext &callContext) {
+static std::optional<hlfir::EntityWithAttributes> genHLFIRIntrinsicRefCore(
+    Fortran::lower::PreparedActualArguments &loweredActuals,
+    const Fortran::evaluate::SpecificIntrinsic *intrinsic,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    CallContext &callContext) {
   if (!useHlfirIntrinsicOps)
     return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
                                callContext);
 
   fir::FirOpBuilder &builder = callContext.getBuilder();
   mlir::Location loc = callContext.loc;
-
-  auto getOperandVector = [&](PreparedActualArguments &loweredActuals) {
-    llvm::SmallVector<mlir::Value> operands;
-    operands.reserve(loweredActuals.size());
-
-    for (size_t i = 0; i < loweredActuals.size(); ++i) {
-      std::optional<PreparedActualArgument> arg = loweredActuals[i];
-      if (!arg) {
-        operands.emplace_back();
-        continue;
-      }
-      hlfir::Entity actual = arg->getOriginalActual();
-      mlir::Value valArg;
-
-      // if intrinsic handler has no lowering rules
-      if (!argLowering) {
-        valArg = hlfir::loadTrivialScalar(loc, builder, actual);
-      } else {
-        fir::ArgLoweringRule argRules =
-            fir::lowerIntrinsicArgumentAs(*argLowering, i);
-        if (!argRules.handleDynamicOptional &&
-            argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
-          valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
-        else
-          valArg = actual.getBase();
-      }
-
-      operands.emplace_back(valArg);
-    }
-    return operands;
-  };
-
-  auto computeResultType = [&](mlir::Value argArray,
-                               mlir::Type stmtResultType) -> mlir::Type {
-    hlfir::ExprType::Shape resultShape;
-    mlir::Type normalisedResult =
-        hlfir::getFortranElementOrSequenceType(stmtResultType);
-    mlir::Type elementType;
-    if (auto array = normalisedResult.dyn_cast<fir::SequenceType>()) {
-      resultShape = hlfir::ExprType::Shape{array.getShape()};
-      elementType = array.getEleTy();
-      return hlfir::ExprType::get(builder.getContext(), resultShape,
-                                  elementType,
-                                  /*polymorphic=*/false);
-    }
-    elementType = normalisedResult;
-    return elementType;
-  };
-
-  auto buildSumOperation = [](fir::FirOpBuilder &builder, mlir::Location loc,
-                              mlir::Type resultTy, mlir::Value array,
-                              mlir::Value dim, mlir::Value mask) {
-    return builder.create<hlfir::SumOp>(loc, resultTy, array, dim, mask);
-  };
-
-  auto buildProductOperation = [](fir::FirOpBuilder &builder,
-                                  mlir::Location loc, mlir::Type resultTy,
-                                  mlir::Value array, mlir::Value dim,
-                                  mlir::Value mask) {
-    return builder.create<hlfir::ProductOp>(loc, resultTy, array, dim, mask);
-  };
-
-  auto buildAnyOperation = [](fir::FirOpBuilder &builder, mlir::Location loc,
-                              mlir::Type resultTy, mlir::Value array,
-                              mlir::Value dim, mlir::Value mask) {
-    return builder.create<hlfir::AnyOp>(loc, resultTy, array, dim);
-  };
-
-  auto buildAllOperation = [](fir::FirOpBuilder &builder, mlir::Location loc,
-                              mlir::Type resultTy, mlir::Value array,
-                              mlir::Value dim, mlir::Value mask) {
-    return builder.create<hlfir::AllOp>(loc, resultTy, array, dim);
-  };
-
-  auto buildReductionIntrinsic =
-      [&](PreparedActualArguments &loweredActuals, mlir::Location loc,
-          fir::FirOpBuilder &builder, CallContext &callContext,
-          std::function<mlir::Operation *(fir::FirOpBuilder &, mlir::Location,
-                                          mlir::Type, mlir::Value, mlir::Value,
-                                          mlir::Value)>
-              buildFunc,
-          bool hasMask) -> std::optional<hlfir::EntityWithAttributes> {
-    // shared logic for building the product and sum operations
-    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
-    // dim, mask can be NULL if these arguments were not given
-    mlir::Value array = operands[0];
-    mlir::Value dim = operands[1];
-    if (dim)
-      dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
-
-    mlir::Value mask;
-    if (hasMask)
-      mask = operands[2];
-
-    mlir::Type resultTy = computeResultType(array, *callContext.resultType);
-    auto *intrinsicOp = buildFunc(builder, loc, resultTy, array, dim, mask);
-    return {hlfir::EntityWithAttributes{intrinsicOp->getResult(0)}};
-  };
-
   const std::string intrinsicName = callContext.getProcedureName();
-  if (intrinsicName == "sum") {
-    return buildReductionIntrinsic(loweredActuals, loc, builder, callContext,
-                                   buildSumOperation, true);
-  }
-  if (intrinsicName == "product") {
-    return buildReductionIntrinsic(loweredActuals, loc, builder, callContext,
-                                   buildProductOperation, true);
-  }
-  if (intrinsicName == "matmul") {
-    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
-    mlir::Type resultTy =
-        computeResultType(operands[0], *callContext.resultType);
-    hlfir::MatmulOp matmulOp = builder.create<hlfir::MatmulOp>(
-        loc, resultTy, operands[0], operands[1]);
-
-    return {hlfir::EntityWithAttributes{matmulOp.getResult()}};
-  }
-  if (intrinsicName == "transpose") {
-    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
-    hlfir::ExprType::Shape resultShape;
-    mlir::Type normalisedResult =
-        hlfir::getFortranElementOrSequenceType(*callContext.resultType);
-    auto array = normalisedResult.cast<fir::SequenceType>();
-    llvm::ArrayRef<int64_t> arrayShape = array.getShape();
-    assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2");
-    mlir::Type elementType = array.getEleTy();
-    resultShape.push_back(arrayShape[0]);
-    resultShape.push_back(arrayShape[1]);
-    mlir::Type resultTy = hlfir::ExprType::get(
-        builder.getContext(), resultShape, elementType, /*polymorphic=*/false);
-    hlfir::TransposeOp transposeOp =
-        builder.create<hlfir::TransposeOp>(loc, resultTy, operands[0]);
-
-    return {hlfir::EntityWithAttributes{transposeOp.getResult()}};
-  }
-  if (intrinsicName == "any") {
-    return buildReductionIntrinsic(loweredActuals, loc, builder, callContext,
-                                   buildAnyOperation, false);
-  }
-  if (intrinsicName == "all") {
-    return buildReductionIntrinsic(loweredActuals, loc, builder, callContext,
-                                   buildAllOperation, false);
-  }
-  if (intrinsicName == "dot_product") {
-    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
-    mlir::Type resultTy =
-        computeResultType(operands[0], *callContext.resultType);
-    hlfir::DotProductOp dotProductOp = builder.create<hlfir::DotProductOp>(
-        loc, resultTy, operands[0], operands[1]);
-
-    return {hlfir::EntityWithAttributes{dotProductOp.getResult()}};
-  }
-  if (intrinsicName == "count") {
-    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
-    mlir::Value array = operands[0];
-    mlir::Value dim = operands[1];
-    if (dim)
-      dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
-    mlir::Value kind = operands[2];
-    mlir::Type resultTy = computeResultType(array, *callContext.resultType);
-    hlfir::CountOp countOp =
-        builder.create<hlfir::CountOp>(loc, resultTy, array, dim, kind);
-    return {hlfir::EntityWithAttributes{countOp.getResult()}};
-  }
 
-  if ((intrinsicName == "min" || intrinsicName == "max") &&
-      hlfir::getFortranElementType(callContext.resultType.value())
-          .isa<fir::CharacterType>()) {
-    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
-    assert(operands.size() >= 2);
-
-    hlfir::CharExtremumPredicate pred = (intrinsicName == "min")
-                                            ? hlfir::CharExtremumPredicate::min
-                                            : hlfir::CharExtremumPredicate::max;
-    hlfir::CharExtremumOp charExtremumOp =
-        builder.create<hlfir::CharExtremumOp>(loc, pred,
-                                              mlir::ValueRange{operands});
-    return {hlfir::EntityWithAttributes{charExtremumOp.getResult()}};
+  // transformational intrinsic ops always have a result type
+  if (callContext.resultType) {
+    std::optional<hlfir::EntityWithAttributes> res =
+        Fortran::lower::lowerHlfirIntrinsic(builder, loc, intrinsicName,
+                                            loweredActuals, argLowering,
+                                            *callContext.resultType);
+    if (res)
+      return res;
   }
 
-  // TODO add hlfir operations for other transformational intrinsics here
-
   // fallback to calling the intrinsic via fir.call
   return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
                              callContext);
@@ -1546,14 +1337,14 @@ template <typename ElementalCallBuilderImpl>
 class ElementalCallBuilder {
 public:
   std::optional<hlfir::EntityWithAttributes>
-  genElementalCall(PreparedActualArguments &loweredActuals, bool isImpure,
-                   CallContext &callContext) {
+  genElementalCall(Fortran::lower::PreparedActualArguments &loweredActuals,
+                   bool isImpure, CallContext &callContext) {
     mlir::Location loc = callContext.loc;
     fir::FirOpBuilder &builder = callContext.getBuilder();
     unsigned numArgs = loweredActuals.size();
     // Step 1: dereference pointers/allocatables and compute elemental shape.
     mlir::Value shape;
-    PreparedActualArgument *optionalWithShape;
+    Fortran::lower::PreparedActualArgument *optionalWithShape;
     // 10.1.4 p5. Impure elemental procedures must be called in element order.
     bool mustBeOrdered = isImpure;
     for (unsigned i = 0; i < numArgs; ++i) {
@@ -1693,7 +1484,7 @@ class ElementalUserCallBuilder
                            mlir::FunctionType callSiteType)
       : caller{caller}, callSiteType{callSiteType} {}
   std::optional<hlfir::Entity>
-  genElementalKernel(PreparedActualArguments &loweredActuals,
+  genElementalKernel(Fortran::lower::PreparedActualArguments &loweredActuals,
                      CallContext &callContext) {
     return genUserCall(loweredActuals, caller, callSiteType, callContext);
   }
@@ -1714,9 +1505,9 @@ class ElementalUserCallBuilder
            arg.passBy == PassBy::BaseAddressValueAttribute;
   }
 
-  mlir::Value
-  computeDynamicCharacterResultLength(PreparedActualArguments &loweredActuals,
-                                      CallContext &callContext) {
+  mlir::Value computeDynamicCharacterResultLength(
+      Fortran::lower::PreparedActualArguments &loweredActuals,
+      CallContext &callContext) {
     TODO(callContext.loc,
          "compute elemental function result length parameters in HLFIR");
   }
@@ -1735,7 +1526,7 @@ class ElementalIntrinsicCallBuilder
       : intrinsic{intrinsic}, argLowering{argLowering}, isFunction{isFunction} {
   }
   std::optional<hlfir::Entity>
-  genElementalKernel(PreparedActualArguments &loweredActuals,
+  genElementalKernel(Fortran::lower::PreparedActualArguments &loweredActuals,
                      CallContext &callContext) {
     return genHLFIRIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
                                     callContext);
@@ -1748,9 +1539,9 @@ class ElementalIntrinsicCallBuilder
     return isFunction;
   }
 
-  mlir::Value
-  computeDynamicCharacterResultLength(PreparedActualArguments &loweredActuals,
-                                      CallContext &callContext) {
+  mlir::Value computeDynamicCharacterResultLength(
+      Fortran::lower::PreparedActualArguments &loweredActuals,
+      CallContext &callContext) {
     if (intrinsic)
       if (intrinsic->name == "adjustr" || intrinsic->name == "adjustl" ||
           intrinsic->name == "merge")
@@ -1816,7 +1607,7 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic *intrinsic,
                        callContext.procRef, *intrinsic, converter))
     TODO(loc, "special cases of intrinsic with optional arguments");
 
-  PreparedActualArguments loweredActuals;
+  Fortran::lower::PreparedActualArguments loweredActuals;
   const fir::IntrinsicArgumentLoweringRules *argLowering =
       fir::getIntrinsicArgumentLowering(callContext.getProcedureName());
   for (const auto &arg : llvm::enumerate(callContext.procRef.arguments())) {
@@ -1845,7 +1636,7 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic *intrinsic,
            !fir::lowerIntrinsicArgumentAs(*argLowering, arg.index())
                 .handleDynamicOptional) &&
           "TYPE(*) are not expected to appear as optional intrinsic arguments");
-      loweredActuals.push_back(PreparedActualArgument{
+      loweredActuals.push_back(Fortran::lower::PreparedActualArgument{
           hlfir::Entity{*var}, /*isPresent=*/std::nullopt});
       continue;
     }
@@ -1861,7 +1652,8 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic *intrinsic,
             genIsPresentIfArgMaybeAbsent(loc, loweredActual, *expr, callContext,
                                          /*passAsAllocatableOrPointer=*/false);
     }
-    loweredActuals.push_back(PreparedActualArgument{loweredActual, isPresent});
+    loweredActuals.push_back(
+        Fortran::lower::PreparedActualArgument{loweredActual, isPresent});
   }
 
   if (callContext.isElementalProcWithArrayArgs()) {
@@ -1898,7 +1690,7 @@ genProcedureRef(CallContext &callContext) {
                                          callContext.converter);
   mlir::FunctionType callSiteType = caller.genFunctionType();
 
-  PreparedActualArguments loweredActuals;
+  Fortran::lower::PreparedActualArguments loweredActuals;
   // Lower the actual arguments
   for (const Fortran::lower::CallInterface<
            Fortran::lower::CallerInterface>::PassedEntity &arg :
@@ -1933,7 +1725,7 @@ genProcedureRef(CallContext &callContext) {
                 Fortran::lower::CallerInterface::PassEntityBy::MutableBox);
 
       loweredActuals.emplace_back(
-          PreparedActualArgument{loweredActual, isPresent});
+          Fortran::lower::PreparedActualArgument{loweredActual, isPresent});
     } else {
       // Optional dummy argument for which there is no actual argument.
       loweredActuals.emplace_back(std::nullopt);

diff  --git a/flang/lib/Lower/HlfirIntrinsics.cpp b/flang/lib/Lower/HlfirIntrinsics.cpp
new file mode 100644
index 00000000000000..89ca1c4114d2e7
--- /dev/null
+++ b/flang/lib/Lower/HlfirIntrinsics.cpp
@@ -0,0 +1,296 @@
+//===-- HlfirIntrinsics.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Lower/HlfirIntrinsics.h"
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "flang/Optimizer/Builder/IntrinsicCall.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVector.h"
+#include <mlir/IR/ValueRange.h>
+
+namespace {
+/// Common base: each subclass lowers one intrinsic call to an HLFIR op.
+class HlfirTransformationalIntrinsic {
+public:
+  explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder,
+                                          mlir::Location loc)
+      : builder(builder), loc(loc) {}
+
+  virtual ~HlfirTransformationalIntrinsic() = default;
+  /// Runs lowerImpl() and wraps the resulting value in an entity.
+  hlfir::EntityWithAttributes
+  lower(const Fortran::lower::PreparedActualArguments &loweredActuals,
+        const fir::IntrinsicArgumentLoweringRules *argLowering,
+        mlir::Type stmtResultType) {
+    mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType);
+    return {hlfir::EntityWithAttributes{res}};
+  }
+
+protected:
+  fir::FirOpBuilder &builder;
+  mlir::Location loc;
+  /// Builds the HLFIR operation for the call and returns its result.
+  virtual mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) = 0;
+  /// One value per actual; statically absent actuals give a null value.
+  llvm::SmallVector<mlir::Value> getOperandVector(
+      const Fortran::lower::PreparedActualArguments &loweredActuals,
+      const fir::IntrinsicArgumentLoweringRules *argLowering);
+  /// Array results get an hlfir::ExprType; argArray is currently unused.
+  mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType);
+  /// Shorthand for builder.create<OP>(loc, args...).
+  template <typename OP, typename... BUILD_ARGS>
+  inline OP createOp(BUILD_ARGS... args) {
+    return builder.create<OP>(loc, args...);
+  }
+};
+/// sum/product/any/all reductions; HAS_MASK selects a third mask operand.
+template <typename OP, bool HAS_MASK>
+class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
+public:
+  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+};
+using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
+using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
+using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
+using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
+/// Two-operand intrinsics (matmul, dot_product) lowered as OP(lhs, rhs).
+template <typename OP>
+class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
+public:
+  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+};
+using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>;
+using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>;
+/// Lowering for the transpose intrinsic.
+class HlfirTransposeLowering : public HlfirTransformationalIntrinsic {
+public:
+  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+};
+/// Lowering for the count intrinsic.
+class HlfirCountLowering : public HlfirTransformationalIntrinsic {
+public:
+  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+};
+/// Character min/max; \p pred selects which extremum is computed.
+class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic {
+public:
+  HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc,
+                            hlfir::CharExtremumPredicate pred)
+      : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {}
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+
+protected:
+  hlfir::CharExtremumPredicate pred;
+};
+
+} // namespace
+
+llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering) {
+  llvm::SmallVector<mlir::Value> operands;
+  operands.reserve(loweredActuals.size());
+  // One slot per actual, so operand i always corresponds to argument i.
+  for (size_t i = 0; i < loweredActuals.size(); ++i) {
+    const std::optional<Fortran::lower::PreparedActualArgument> &arg =
+        loweredActuals[i];
+    if (!arg) { // statically absent: keep a null placeholder
+      operands.emplace_back();
+      continue;
+    }
+    hlfir::Entity actual = arg->getOriginalActual();
+    mlir::Value valArg;
+    // Without lowering rules, trivial scalars are simply loaded.
+    if (!argLowering) {
+      valArg = hlfir::loadTrivialScalar(loc, builder, actual);
+    } else {
+      fir::ArgLoweringRule argRules =
+          fir::lowerIntrinsicArgumentAs(*argLowering, i);
+      if (!argRules.handleDynamicOptional &&
+          argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
+        valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
+      else
+        valArg = actual.getBase(); // inquired or dynamically optional
+    }
+
+    operands.emplace_back(valArg);
+  }
+  return operands;
+}
+
+mlir::Type
+HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
+                                                  mlir::Type stmtResultType) {
+  // Array-shaped results become an hlfir::ExprType with the same shape and
+  // element type; scalar results use the normalised element type directly.
+  // NOTE: argArray does not take part in the computation.
+  mlir::Type norm = hlfir::getFortranElementOrSequenceType(stmtResultType);
+  auto seqTy = norm.dyn_cast<fir::SequenceType>();
+  if (!seqTy)
+    return norm;
+  hlfir::ExprType::Shape shape{seqTy.getShape()};
+  return hlfir::ExprType::get(builder.getContext(), shape, seqTy.getEleTy(),
+                              /*polymorphic=*/false);
+}
+
+/// Lowers a reduction intrinsic to OP. Operand 0 is the array and operand 1
+/// is DIM. When HAS_MASK is set, operand 2 is MASK. DIM and MASK are null
+/// mlir::Values when the argument was not given. A present DIM is loaded as
+/// a trivial scalar before being passed to OP.
+template <typename OP, bool HAS_MASK>
+mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  llvm::SmallVector<mlir::Value> args =
+      getOperandVector(loweredActuals, argLowering);
+  mlir::Value array = args[0];
+  mlir::Value dim = args[1];
+  if (dim)
+    dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
+
+  mlir::Type resultTy = computeResultType(array, stmtResultType);
+  // The else branch must stay a discarded statement: the mask-taking ops
+  // may not accept the three-operand form.
+  if constexpr (HAS_MASK) {
+    mlir::Value mask = args[2];
+    return createOp<OP>(resultTy, array, dim, mask);
+  } else {
+    return createOp<OP>(resultTy, array, dim);
+  }
+}
+
+/// Lowers a two-operand intrinsic to OP. Both lowered operands are passed
+/// through unchanged, and the result type comes from computeResultType.
+template <typename OP>
+mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  llvm::SmallVector<mlir::Value> args =
+      getOperandVector(loweredActuals, argLowering);
+  mlir::Value lhs = args[0];
+  mlir::Value rhs = args[1];
+  mlir::Type resultTy = computeResultType(lhs, stmtResultType);
+  return createOp<OP>(resultTy, lhs, rhs);
+}
+
+/// Lowers TRANSPOSE to hlfir::TransposeOp. The statement result type already
+/// carries the transposed extents, so the shared computeResultType helper
+/// produces the hlfir.expr result type directly. The result must be a rank-2
+/// array.
+mlir::Value HlfirTransposeLowering::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  auto operands = getOperandVector(loweredActuals, argLowering);
+  mlir::Type resultTy = computeResultType(operands[0], stmtResultType);
+  assert(resultTy.isa<hlfir::ExprType>() &&
+         resultTy.cast<hlfir::ExprType>().getShape().size() == 2 &&
+         "the result of transpose must be a rank 2 array");
+  return createOp<hlfir::TransposeOp>(resultTy, operands[0]);
+}
+
+/// Lowers COUNT to hlfir::CountOp. Operand 0 is the array, operand 1 is DIM
+/// and operand 2 is KIND. DIM and KIND are null mlir::Values when the
+/// argument was not given. A present DIM is loaded as a trivial scalar.
+mlir::Value HlfirCountLowering::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  llvm::SmallVector<mlir::Value> args =
+      getOperandVector(loweredActuals, argLowering);
+  mlir::Value arrayArg = args[0];
+  mlir::Value kindArg = args[2];
+  mlir::Value dimArg = args[1];
+  if (dimArg)
+    dimArg = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dimArg});
+  mlir::Type resultTy = computeResultType(arrayArg, stmtResultType);
+  return createOp<hlfir::CountOp>(resultTy, arrayArg, dimArg, kindArg);
+}
+
+/// Lowers character MIN/MAX to a single hlfir::CharExtremumOp over all
+/// lowered arguments, using the predicate chosen at construction. At least
+/// two operands are required. The statement result type is not used here.
+mlir::Value HlfirCharExtremumLowering::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  llvm::SmallVector<mlir::Value> strings =
+      getOperandVector(loweredActuals, argLowering);
+  assert(2 <= strings.size());
+  return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{strings});
+}
+
+/// Dispatches a transformational intrinsic call to its HLFIR lowering helper
+/// by intrinsic name. MIN/MAX are handled only when the statement result is
+/// a fir::CharacterType. Returns std::nullopt for any intrinsic that has no
+/// HLFIR op lowering here, so the caller can fall back.
+std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
+    fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  // Runs one lowering helper on this call's arguments.
+  auto run = [&](auto &&lowering)
+      -> std::optional<hlfir::EntityWithAttributes> {
+    return lowering.lower(loweredActuals, argLowering, stmtResultType);
+  };
+
+  if (name == "sum")
+    return run(HlfirSumLowering{builder, loc});
+  if (name == "product")
+    return run(HlfirProductLowering{builder, loc});
+  if (name == "any")
+    return run(HlfirAnyLowering{builder, loc});
+  if (name == "all")
+    return run(HlfirAllLowering{builder, loc});
+  if (name == "matmul")
+    return run(HlfirMatmulLowering{builder, loc});
+  if (name == "dot_product")
+    return run(HlfirDotProductLowering{builder, loc});
+  if (name == "transpose")
+    return run(HlfirTransposeLowering{builder, loc});
+  if (name == "count")
+    return run(HlfirCountLowering{builder, loc});
+
+  // Only character MIN/MAX have a dedicated lowering in this function.
+  if (!mlir::isa<fir::CharacterType>(stmtResultType))
+    return std::nullopt;
+  if (name == "min")
+    return run(HlfirCharExtremumLowering{builder, loc,
+                                         hlfir::CharExtremumPredicate::min});
+  if (name == "max")
+    return run(HlfirCharExtremumLowering{builder, loc,
+                                         hlfir::CharExtremumPredicate::max});
+  return std::nullopt;
+}


        


More information about the flang-commits mailing list