[flang-commits] [flang] 9facbb6 - [flang] lower sum intrinsic to hlfir.sum operation
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Mon Feb 13 02:52:16 PST 2023
Author: Tom Eccles
Date: 2023-02-13T10:50:11Z
New Revision: 9facbb694250ef8d0144629fb07e170162733dea
URL: https://github.com/llvm/llvm-project/commit/9facbb694250ef8d0144629fb07e170162733dea
DIFF: https://github.com/llvm/llvm-project/commit/9facbb694250ef8d0144629fb07e170162733dea.diff
LOG: [flang] lower sum intrinsic to hlfir.sum operation
Differential Revision: https://reviews.llvm.org/D142898
Added:
flang/test/Lower/HLFIR/sum.f90
Modified:
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Lower/ConvertCall.cpp
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/Lower/HLFIR/expr-box.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 1321eab11041e..d929850d9ff28 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -280,6 +280,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
$array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
}];
+ // dim and mask can be NULL, array must not be.
+ let builders = [OpBuilder<(ins "mlir::Value":$array,
+ "mlir::Value":$dim,
+ "mlir::Value":$mask,
+ "mlir::Type":$resultType)>];
+
let hasVerifier = 1;
}
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 5a9a244a45f10..5978441517df4 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -19,6 +19,7 @@
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Builder/MutableBox.h"
@@ -26,11 +27,16 @@
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "flang-lower-expr"
+/// When true (the default), transformational intrinsics that have a dedicated
+/// HLFIR operation (such as hlfir.sum) are lowered to that operation. When
+/// false, every intrinsic goes through genIntrinsicRefCore.
+static llvm::cl::opt<bool> useHlfirIntrinsicOps(
+    "use-hlfir-intrinsic-ops", llvm::cl::init(true),
+    llvm::cl::desc("Lower via HLFIR transformational intrinsic operations "
+                   "such as hlfir.sum"));
+
/// Helper to package a Value and its properties into an ExtendedValue.
static fir::ExtendedValue toExtendedValue(mlir::Location loc, mlir::Value base,
llvm::ArrayRef<mlir::Value> extents,
@@ -631,7 +637,8 @@ extendedValueToHlfirEntity(mlir::Location loc, fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv,
llvm::StringRef name) {
mlir::Value firBase = fir::getBase(exv);
- if (fir::isa_trivial(firBase.getType()))
+ mlir::Type firBaseTy = firBase.getType();
+ if (fir::isa_trivial(firBaseTy))
return hlfir::EntityWithAttributes{firBase};
if (auto charTy = firBase.getType().dyn_cast<fir::CharacterType>()) {
// CHAR() intrinsic and BIND(C) procedures returning CHARACTER(1)
@@ -1232,6 +1239,56 @@ genIntrinsicRefCore(PreparedActualArguments &loweredActuals,
return resultEntity;
}
+/// Lower a call to an intrinsic procedure whose actual arguments have been
+/// pre-lowered but not yet prepared according to the interface.
+/// Transformational intrinsics that have a dedicated HLFIR operation
+/// (currently only SUM) are lowered to that operation. Every other intrinsic,
+/// and all of them when -use-hlfir-intrinsic-ops=false, falls back to
+/// genIntrinsicRefCore.
+static std::optional<hlfir::EntityWithAttributes>
+genHLFIRIntrinsicRefCore(PreparedActualArguments &loweredActuals,
+                         const Fortran::evaluate::SpecificIntrinsic &intrinsic,
+                         const fir::IntrinsicArgumentLoweringRules *argLowering,
+                         CallContext &callContext) {
+  if (!useHlfirIntrinsicOps)
+    return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                               callContext);
+
+  // Collect the base value of each original actual argument. An absent
+  // argument becomes a null mlir::Value, so every operand keeps the position
+  // of its argument.
+  auto getOperandVector = [](PreparedActualArguments &actuals) {
+    llvm::SmallVector<mlir::Value> operands;
+    operands.reserve(actuals.size());
+    for (auto &arg : actuals) {
+      if (!arg) {
+        operands.emplace_back();
+        continue;
+      }
+      hlfir::Entity actual = arg->getOriginalActual();
+      operands.emplace_back(actual.getBase());
+    }
+    return operands;
+  };
+
+  fir::FirOpBuilder &builder = callContext.getBuilder();
+  mlir::Location loc = callContext.loc;
+
+  if (intrinsic.name == "sum") {
+    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
+    assert(operands.size() == 3 && "expected ARRAY, DIM and MASK for SUM");
+    mlir::Value array = hlfir::derefPointersAndAllocatables(
+        loc, builder, hlfir::Entity{operands[0]});
+    // DIM and MASK are null when the corresponding argument was not given.
+    mlir::Value dim = operands[1];
+    if (dim)
+      dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
+    // Like ARRAY, MASK may be a POINTER or ALLOCATABLE variable. Pass its
+    // target rather than the reference to its descriptor.
+    mlir::Value mask = operands[2];
+    if (mask)
+      mask = hlfir::derefPointersAndAllocatables(loc, builder,
+                                                 hlfir::Entity{mask});
+    hlfir::SumOp sumOp = builder.create<hlfir::SumOp>(loc, array, dim, mask,
+                                                      *callContext.resultType);
+    return {hlfir::EntityWithAttributes{sumOp.getResult()}};
+  }
+
+  // TODO: add HLFIR operations for other transformational intrinsics here.
+
+  // Fall back to the generic intrinsic lowering.
+  return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                             callContext);
+}
+
namespace {
template <typename ElementalCallBuilderImpl>
class ElementalCallBuilder {
@@ -1405,8 +1462,8 @@ class ElementalIntrinsicCallBuilder
std::optional<hlfir::Entity>
genElementalKernel(PreparedActualArguments &loweredActuals,
CallContext &callContext) {
- return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
- callContext);
+ return genHLFIRIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+ callContext);
}
// Elemental intrinsic functions cannot modify their arguments.
bool argMayBeModifiedByCall(int) const { return !isFunction; }
@@ -1512,8 +1569,8 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic &intrinsic,
.genElementalCall(loweredActuals, /*isImpure=*/!isFunction, callContext)
.value();
}
- std::optional<hlfir::EntityWithAttributes> result =
- genIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext);
+ std::optional<hlfir::EntityWithAttributes> result = genHLFIRIntrinsicRefCore(
+ loweredActuals, intrinsic, argLowering, callContext);
if (result && result->getType().isa<hlfir::ExprType>()) {
fir::FirOpBuilder *bldr = &callContext.getBuilder();
callContext.stmtCtx.attachCleanup(
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index feea448b3d4c9..cd24b0f5079e3 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -489,6 +489,28 @@ mlir::LogicalResult hlfir::SumOp::verify() {
return mlir::success();
}
+/// Build an hlfir.sum operation. ARRAY is required; DIM and MASK may be null
+/// when the corresponding argument is absent. The result element type is the
+/// element type of ARRAY. The result shape is taken from RESULTTYPE when it is
+/// a fir.array type; otherwise the shape is empty (a scalar expression).
+void hlfir::SumOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
+                         mlir::Value array, mlir::Value dim, mlir::Value mask,
+                         mlir::Type resultType) {
+  assert(array && "array argument is not optional");
+
+  fir::SequenceType arrayTy =
+      hlfir::getFortranElementOrSequenceType(array.getType())
+          .dyn_cast<fir::SequenceType>();
+  assert(arrayTy && "array must be of array type");
+  mlir::Type numTy = arrayTy.getEleTy();
+
+  // Get the result shape from the statement-level result type. Use a distinct
+  // name so the `array` parameter is not shadowed.
+  hlfir::ExprType::Shape resultShape;
+  if (auto resultSeqTy = resultType.dyn_cast<fir::SequenceType>())
+    resultShape = hlfir::ExprType::Shape{resultSeqTy.getShape()};
+  mlir::Type exprTy = hlfir::ExprType::get(builder.getContext(), resultShape,
+                                           numTy, /*polymorphic=*/false);
+
+  build(builder, result, exprTy, array, dim, mask);
+}
+
//===----------------------------------------------------------------------===//
// AssociateOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/HLFIR/expr-box.f90 b/flang/test/Lower/HLFIR/expr-box.f90
index 330e857aa9b76..e42c01a8f91d9 100644
--- a/flang/test/Lower/HLFIR/expr-box.f90
+++ b/flang/test/Lower/HLFIR/expr-box.f90
@@ -27,12 +27,12 @@ subroutine test_place_in_memory_and_embox()
! CHECK: fir.call @_FortranAioOutputDescriptor(%{{.*}}, %[[BOX_CAST]])
! check we can box a trivial value
-subroutine sumMask(s, a)
+subroutine productMask(s, a)
integer :: s
integer :: a(:)
- s = sum(a, mask=.true.)
+ s = product(a, mask=.true.)
endsubroutine
-! CHECK-LABEL: func.func @_QPsummask(
+! CHECK-LABEL: func.func @_QPproductmask(
! CHECK: %[[TRUE:.*]] = arith.constant true
! CHECK: %[[ALLOC:.*]] = fir.alloca !fir.logical<4>
! CHECK: %[[TRUE_L4:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
diff --git a/flang/test/Lower/HLFIR/sum.f90 b/flang/test/Lower/HLFIR/sum.f90
new file mode 100644
index 0000000000000..32a2423da698d
--- /dev/null
+++ b/flang/test/Lower/HLFIR/sum.f90
@@ -0,0 +1,112 @@
+! Test lowering of SUM intrinsic to HLFIR
+! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s
+
+! simple 1 argument SUM: the result is a scalar hlfir.expr
+subroutine sum1(a, s)
+ integer :: a(:), s
+ s = SUM(a)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum1(
+! CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32>
+! CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG: %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-NEXT: %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>) -> !hlfir.expr<i32>
+! CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<i32>, !fir.ref<i32>
+! CHECK-NEXT: hlfir.destroy %[[EXPR]]
+! CHECK-NEXT: return
+! CHECK-NEXT: }
+
+! sum with a DIM variable passed by reference: DIM is loaded before hlfir.sum
+! and the result is a rank-1 hlfir.expr
+subroutine sum2(a, s, d)
+ integer :: a(:,:), s(:), d
+ s = SUM(a, d)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum2(
+! CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "s"}, %[[ARG2:.*]]: !fir.ref<i32>
+! CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG: %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG: %[[DIM_REF:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT: %[[DIM:.*]] = fir.load %[[DIM_REF]]#0 : !fir.ref<i32>
+! CHECK-NEXT: %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 dim %[[DIM]] {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
+! CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<?xi32>, !fir.box<!fir.array<?xi32>>
+! CHECK-NEXT: hlfir.destroy %[[EXPR]]
+! CHECK-NEXT: return
+! CHECK-NEXT: }
+
+! sum with a scalar MASK argument: a LOGICAL second positional argument is
+! lowered as the mask operand, not as dim
+subroutine sum3(a, s, m)
+ integer :: a(:), s
+ logical :: m
+ s = SUM(a, m)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum3(
+! CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32> {fir.bindc_name = "s"}, %[[ARG2:.*]]: !fir.ref<!fir.logical<4>>
+! CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG: %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG: %[[MASK:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT: %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 mask %[[MASK]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.logical<4>>) -> !hlfir.expr<i32>
+! CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<i32>, !fir.ref<i32>
+! CHECK-NEXT: hlfir.destroy %[[EXPR]]
+! CHECK-NEXT: return
+! CHECK-NEXT: }
+
+! sum with an array MASK argument, passed to hlfir.sum as a descriptor
+subroutine sum4(a, s, m)
+ integer :: a(:), s
+ logical :: m(:)
+ s = SUM(a, m)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum4(
+! CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32> {fir.bindc_name = "s"}, %[[ARG2:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
+! CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG: %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG: %[[MASK:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT: %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 mask %[[MASK]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+! CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<i32>, !fir.ref<i32>
+! CHECK-NEXT: hlfir.destroy %[[EXPR]]
+! CHECK-NEXT: return
+! CHECK-NEXT: }
+
+! sum with all 3 arguments: DIM and MASK are constants passed by value, and
+! ARRAY is a fixed-shape variable that is not boxed
+subroutine sum5(s)
+ integer :: s(2)
+ integer :: a(2,2) = reshape((/1, 2, 3, 4/), [2,2])
+ s = sum(a, 1, .true.)
+end subroutine
+! CHECK-LABEL: func.func @_QPsum5
+! CHECK: %[[ARG0:.*]]: !fir.ref<!fir.array<2xi32>>
+! CHECK-DAG: %[[ADDR:.*]] = fir.address_of({{.*}}) : !fir.ref<!fir.array<2x2xi32>>
+! CHECK-DAG: %[[ARRAY_SHAPE:.*]] = fir.shape {{.*}} -> !fir.shape<2>
+! CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ADDR]](%[[ARRAY_SHAPE]])
+! CHECK-DAG: %[[OUT_SHAPE:.*]] = fir.shape {{.*}} -> !fir.shape<1>
+! CHECK-DAG: %[[OUT:.*]]:2 = hlfir.declare %[[ARG0]](%[[OUT_SHAPE]])
+! CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+! CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK-NEXT: %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 dim %[[C1]] mask %[[TRUE]] {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<2x2xi32>>, i32, i1) -> !hlfir.expr<2xi32>
+! CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<2xi32>, !fir.ref<!fir.array<2xi32>>
+! CHECK-NEXT: hlfir.destroy %[[EXPR]] : !hlfir.expr<2xi32>
+! CHECK-NEXT: return
+! CHECK-NEXT: }
+
+! sum with a parenthesized POINTER DIM argument: (d) is an expression, so the
+! pointer target is loaded and wrapped in hlfir.no_reassoc before hlfir.sum
+subroutine sum6(a, s, d)
+ integer, pointer :: d
+ real :: a(:,:), s(:)
+ s = sum(a, (d))
+end subroutine
+! CHECK-LABEL: func.func @_QPsum6(
+! CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<?x?xf32>>
+! CHECK: %[[ARG1:.*]]: !fir.box<!fir.array<?xf32>>
+! CHECK: %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>
+! CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]]
+! CHECK-DAG: %[[OUT:.*]]:2 = hlfir.declare %[[ARG1]]
+! CHECK-DAG: %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+! CHECK-NEXT: %[[DIM_BOX:.*]] = fir.load %[[DIM_VAR]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+! CHECK-NEXT: %[[DIM_ADDR:.*]] = fir.box_addr %[[DIM_BOX]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+! CHECK-NEXT: %[[DIM0:.*]] = fir.load %[[DIM_ADDR]] : !fir.ptr<i32>
+! CHECK-NEXT: %[[DIM1:.*]] = hlfir.no_reassoc %[[DIM0]] : i32
+! CHECK-NEXT: %[[EXPR:.*]] = hlfir.sum %[[ARRAY]]#0 dim %[[DIM1]] {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xf32>>, i32) -> !hlfir.expr<?xf32>
+! CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[OUT]]#0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+! CHECK-NEXT: hlfir.destroy %[[EXPR]]
+! CHECK-NEXT: return
+! CHECK-NEXT: }
More information about the flang-commits mailing list