[flang-commits] [flang] 9facbb6 - [flang] lower sum intrinsic to hlfir.sum operation

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Mon Feb 13 02:52:16 PST 2023


Author: Tom Eccles
Date: 2023-02-13T10:50:11Z
New Revision: 9facbb694250ef8d0144629fb07e170162733dea

URL: https://github.com/llvm/llvm-project/commit/9facbb694250ef8d0144629fb07e170162733dea
DIFF: https://github.com/llvm/llvm-project/commit/9facbb694250ef8d0144629fb07e170162733dea.diff

LOG: [flang] lower sum intrinsic to hlfir.sum operation

Differential Revision: https://reviews.llvm.org/D142898

Added: 
    flang/test/Lower/HLFIR/sum.f90

Modified: 
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Lower/ConvertCall.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/test/Lower/HLFIR/expr-box.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 1321eab11041e..d929850d9ff28 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -280,6 +280,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
     $array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
   }];
 
+  // dim and mask can be NULL, array must not be.
+  let builders = [OpBuilder<(ins "mlir::Value":$array,
+                                 "mlir::Value":$dim,
+                                 "mlir::Value":$mask,
+                                 "mlir::Type":$resultType)>];
+
   let hasVerifier = 1;
 }
 

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 5a9a244a45f10..5978441517df4 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
@@ -26,11 +27,16 @@
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "flang-lower-expr"
 
+static llvm::cl::opt<bool> useHlfirIntrinsicOps("use-hlfir-intrinsic-ops",
+    llvm::cl::desc("Lower via HLFIR transformational intrinsic operations "
+                   "such as hlfir.sum"), llvm::cl::init(true));
+
 /// Helper to package a Value and its properties into an ExtendedValue.
 static fir::ExtendedValue toExtendedValue(mlir::Location loc, mlir::Value base,
                                           llvm::ArrayRef<mlir::Value> extents,
@@ -631,7 +637,8 @@ extendedValueToHlfirEntity(mlir::Location loc, fir::FirOpBuilder &builder,
                            const fir::ExtendedValue &exv,
                            llvm::StringRef name) {
   mlir::Value firBase = fir::getBase(exv);
-  if (fir::isa_trivial(firBase.getType()))
+  mlir::Type firBaseTy = firBase.getType();
+  if (fir::isa_trivial(firBaseTy))
     return hlfir::EntityWithAttributes{firBase};
   if (auto charTy = firBase.getType().dyn_cast<fir::CharacterType>()) {
     // CHAR() intrinsic and BIND(C) procedures returning CHARACTER(1)
@@ -1232,6 +1239,56 @@ genIntrinsicRefCore(PreparedActualArguments &loweredActuals,
   return resultEntity;
 }
 
+/// Lower calls to intrinsic procedures with actual arguments that have been
+/// pre-lowered. Transformational intrinsics with a dedicated HLFIR operation
+/// (currently only SUM) are lowered to that operation; all other intrinsics,
+/// or all of them with -use-hlfir-intrinsic-ops=false, use genIntrinsicRefCore.
+static std::optional<hlfir::EntityWithAttributes>
+genHLFIRIntrinsicRefCore(PreparedActualArguments &loweredActuals,
+                         const Fortran::evaluate::SpecificIntrinsic &intrinsic,
+                         const fir::IntrinsicArgumentLoweringRules *argLowering,
+                         CallContext &callContext) {
+  if (!useHlfirIntrinsicOps)
+    return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                               callContext);
+
+  fir::FirOpBuilder &builder = callContext.getBuilder();
+  mlir::Location loc = callContext.loc;
+
+  // Absent arguments map to null; POINTER/ALLOCATABLE actuals are dereferenced.
+  auto getOperandVector = [&](PreparedActualArguments &actuals) {
+    llvm::SmallVector<mlir::Value> operands;
+    operands.reserve(actuals.size());
+    for (auto &actual : actuals) {
+      if (!actual)
+        operands.emplace_back();
+      else
+        operands.emplace_back(hlfir::derefPointersAndAllocatables(
+            loc, builder, actual->getOriginalActual()));
+    }
+    return operands;
+  };
+
+  if (intrinsic.name == "sum") {
+    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
+    assert(operands.size() == 3 && "SUM has ARRAY, DIM and MASK arguments");
+    // dim and mask are null values if these arguments were not given.
+    mlir::Value array = operands[0];
+    mlir::Value dim = operands[1];
+    if (dim)
+      dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
+    mlir::Value mask = operands[2];
+    hlfir::SumOp sumOp = builder.create<hlfir::SumOp>(loc, array, dim, mask,
+                                                      *callContext.resultType);
+    return {hlfir::EntityWithAttributes{sumOp.getResult()}};
+  }
+
+  // TODO add hlfir operations for other transformational intrinsics here.
+  // Until then, fall back to the lowering used without HLFIR intrinsic ops.
+  return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                             callContext);
+}
+
 namespace {
 template <typename ElementalCallBuilderImpl>
 class ElementalCallBuilder {
@@ -1405,8 +1462,8 @@ class ElementalIntrinsicCallBuilder
   std::optional<hlfir::Entity>
   genElementalKernel(PreparedActualArguments &loweredActuals,
                      CallContext &callContext) {
-    return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
-                               callContext);
+    return genHLFIRIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                                    callContext);
   }
   // Elemental intrinsic functions cannot modify their arguments.
   bool argMayBeModifiedByCall(int) const { return !isFunction; }
@@ -1512,8 +1569,8 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic &intrinsic,
         .genElementalCall(loweredActuals, /*isImpure=*/!isFunction, callContext)
         .value();
   }
-  std::optional<hlfir::EntityWithAttributes> result =
-      genIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext);
+  std::optional<hlfir::EntityWithAttributes> result = genHLFIRIntrinsicRefCore(
+      loweredActuals, intrinsic, argLowering, callContext);
   if (result && result->getType().isa<hlfir::ExprType>()) {
     fir::FirOpBuilder *bldr = &callContext.getBuilder();
     callContext.stmtCtx.attachCleanup(

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index feea448b3d4c9..cd24b0f5079e3 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -489,6 +489,28 @@ mlir::LogicalResult hlfir::SumOp::verify() {
   return mlir::success();
 }
 
+void hlfir::SumOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
+                         mlir::Value array, mlir::Value dim, mlir::Value mask,
+                         mlir::Type resultType) {
+  assert(array && "array argument is not optional");
+
+  // The result element type is the element type of ARRAY.
+  fir::SequenceType arrayTy =
+      hlfir::getFortranElementOrSequenceType(array.getType())
+          .dyn_cast<fir::SequenceType>();
+  assert(arrayTy && "array must be of array type");
+  mlir::Type numTy = arrayTy.getEleTy();
+
+  // resultType is the Fortran type of the call result from the statement
+  // context; only its shape is used (an empty shape means a scalar result).
+  hlfir::ExprType::Shape resultShape;
+  if (auto resultArrayTy = resultType.dyn_cast<fir::SequenceType>())
+    resultShape = hlfir::ExprType::Shape{resultArrayTy.getShape()};
+  mlir::Type exprTy = hlfir::ExprType::get(builder.getContext(), resultShape,
+                                           numTy, /*polymorphic=*/false);
+  build(builder, result, exprTy, array, dim, mask);
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Lower/HLFIR/expr-box.f90 b/flang/test/Lower/HLFIR/expr-box.f90
index 330e857aa9b76..e42c01a8f91d9 100644
--- a/flang/test/Lower/HLFIR/expr-box.f90
+++ b/flang/test/Lower/HLFIR/expr-box.f90
@@ -27,12 +27,12 @@ subroutine test_place_in_memory_and_embox()
 ! CHECK:  fir.call @_FortranAioOutputDescriptor(%{{.*}}, %[[BOX_CAST]])
 
 ! check we can box a trivial value
-subroutine sumMask(s, a)
+subroutine productMask(s, a)
   integer :: s
   integer :: a(:)
-  s = sum(a, mask=.true.)
+  s = product(a, mask=.true.)
 endsubroutine
-! CHECK-LABEL: func.func @_QPsummask(
+! CHECK-LABEL: func.func @_QPproductmask(
 ! CHECK:      %[[TRUE:.*]] = arith.constant true
 ! CHECK:      %[[ALLOC:.*]] = fir.alloca !fir.logical<4>
 ! CHECK:      %[[TRUE_L4:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>

diff  --git a/flang/test/Lower/HLFIR/sum.f90 b/flang/test/Lower/HLFIR/sum.f90
new file mode 100644
index 0000000000000..32a2423da698d
--- /dev/null
+++ b/flang/test/Lower/HLFIR/sum.f90
@@ -0,0 +1,112 @@
+! Test lowering of SUM intrinsic to HLFIR
+! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s
+
+! simple 1-argument SUM: the result is a scalar !hlfir.expr
+subroutine sum1(a, s)
+  integer :: a(:), s
+  s = SUM(a)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum1(
+! CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32>
+! CHECK-DAG:     %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG:     %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>) ->
+! CHECK-SAME:    !hlfir.expr<i32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<i32>, !fir.ref<i32>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }
+
+! sum with a DIM dummy argument: it is passed by reference and loaded before hlfir.sum
+subroutine sum2(a, s, d)
+  integer :: a(:,:), s(:), d
+  s = SUM(a, d)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum2(
+! CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "s"}, %[[ARG2:.*]]: !fir.ref<i32>
+! CHECK-DAG:     %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG:     %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG:     %[[DIM_REF:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT:    %[[DIM:.*]] = fir.load %[[DIM_REF]]#0 : !fir.ref<i32>
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 dim %[[DIM]] {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<?xi32>, !fir.box<!fir.array<?xi32>>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }
+
+! sum with a scalar MASK given positionally; its reference is passed to hlfir.sum as-is
+subroutine sum3(a, s, m)
+  integer :: a(:), s
+  logical :: m
+  s = SUM(a, m)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum3(
+! CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32> {fir.bindc_name = "s"}, %[[ARG2:.*]]: !fir.ref<!fir.logical<4>>
+! CHECK-DAG:     %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG:     %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 mask %[[MASK]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.logical<4>>) -> !hlfir.expr<i32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<i32>, !fir.ref<i32>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }
+
+! sum with an assumed-shape array MASK given positionally; it is passed to hlfir.sum as a box
+subroutine sum4(a, s, m)
+  integer :: a(:), s
+  logical :: m(:)
+  s = SUM(a, m)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum4(
+! CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32> {fir.bindc_name = "s"}, %[[ARG2:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
+! CHECK-DAG:     %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG:     %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 mask %[[MASK]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<i32>, !fir.ref<i32>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }
+
+! sum with all 3 arguments: DIM and MASK are constants, ARRAY is a non-boxed explicit-shape variable
+subroutine sum5(s)
+  integer :: s(2)
+  integer :: a(2,2) = reshape((/1, 2, 3, 4/), [2,2])
+  s = sum(a, 1, .true.)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum5(
+! CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<2xi32>>
+! CHECK-DAG:     %[[ADDR:.*]] = fir.address_of({{.*}}) : !fir.ref<!fir.array<2x2xi32>>
+! CHECK-DAG:     %[[ARRAY_SHAPE:.*]] = fir.shape {{.*}} -> !fir.shape<2>
+! CHECK-DAG:     %[[ARRAY:.*]]:2 = hlfir.declare %[[ADDR]](%[[ARRAY_SHAPE]])
+! CHECK-DAG:     %[[OUT_SHAPE:.*]] = fir.shape {{.*}} -> !fir.shape<1>
+! CHECK-DAG:     %[[OUT:.*]]:2 = hlfir.declare %[[ARG0]](%[[OUT_SHAPE]])
+! CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
+! CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 dim %[[C1]] mask %[[TRUE]] {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<2x2xi32>>, i32, i1) -> !hlfir.expr<2xi32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<2xi32>, !fir.ref<!fir.array<2xi32>>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]] : !hlfir.expr<2xi32>
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }
+
+subroutine sum6(a, s, d) ! DIM is a parenthesized POINTER: its value is loaded, then wrapped in hlfir.no_reassoc
+  integer, pointer :: d
+  real :: a(:,:), s(:)
+  s = sum(a, (d))
+end subroutine
+! CHECK-LABEL: func.func @_QPsum6(
+! CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?xf32>>
+! CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?xf32>>
+! CHECK:           %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>
+! CHECK-DAG:     %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG:     %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG:     %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT:     %[[DIM_BOX:.*]] = fir.load %[[DIM_VAR]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+! CHECK-NEXT:    %[[DIM_ADDR:.*]] = fir.box_addr %[[DIM_BOX]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+! CHECK-NEXT:    %[[DIM0:.*]] = fir.load %[[DIM_ADDR]] : !fir.ptr<i32>
+! CHECK-NEXT:    %[[DIM1:.*]] = hlfir.no_reassoc %[[DIM0]] : i32
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 dim %[[DIM1]] {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xf32>>, i32) -> !hlfir.expr<?xf32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }


        


More information about the flang-commits mailing list