[flang-commits] [flang] 2b13856 - [flang] Support more data types for reduction in SimplifyIntrinsicsPass.
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Mon Sep 19 12:17:46 PDT 2022
Author: Slava Zakharin
Date: 2022-09-19T12:16:22-07:00
New Revision: 2b138567e0cb126ce2bd726e9c4becb69aed0563
URL: https://github.com/llvm/llvm-project/commit/2b138567e0cb126ce2bd726e9c4becb69aed0563
DIFF: https://github.com/llvm/llvm-project/commit/2b138567e0cb126ce2bd726e9c4becb69aed0563.diff
LOG: [flang] Support more data types for reduction in SimplifyIntrinsicsPass.
All floating-point (not complex) and integer types should be supported now.
Differential Revision: https://reviews.llvm.org/D133818
Added:
Modified:
flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
flang/test/Transforms/simplifyintrinsics.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index a6887bf84fd35..d23736ef8a68e 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -180,10 +180,12 @@ static void genRuntimeSumBody(fir::FirOpBuilder &builder,
// end function RTNAME(Sum)<T>_simplified
auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
- return elementType.isa<mlir::FloatType>()
- ? builder.createRealConstant(loc, elementType,
- llvm::APFloat(0.0))
- : builder.createIntegerConstant(loc, elementType, 0);
+ if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(loc, elementType,
+ llvm::APFloat::getZero(sem));
+ }
+ return builder.createIntegerConstant(loc, elementType, 0);
};
auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
@@ -464,17 +466,22 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
unsigned rank = getDimCount(args[0]);
if (dimAndMaskAbsent && rank == 1) {
mlir::Location loc = call.getLoc();
- mlir::Type type;
fir::FirOpBuilder builder(call, kindMap);
- if (funcName.endswith("Integer4")) {
- type = mlir::IntegerType::get(builder.getContext(), 32);
- } else if (funcName.endswith("Real8")) {
- type = mlir::FloatType::getF64(builder.getContext());
- } else {
+
+ // Support only floating point and integer results now.
+ mlir::Type resultType = call.getResult(0).getType();
+ if (!resultType.isa<mlir::FloatType>() &&
+ !resultType.isa<mlir::IntegerType>())
return;
- }
- auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genNoneBoxType(builder, type);
+
+ auto argType = getArgElementType(args[0]);
+ if (!argType)
+ return;
+ assert(*argType == resultType &&
+ "Argument/result types mismatch in reduction");
+
+ auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
+ return genNoneBoxType(builder, resultType);
};
mlir::func::FuncOp newFunc =
getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index 0d652878f6524..b5d24c5785243 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -153,6 +153,65 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// -----
+// Call to SUM with 1D F32 is replaced.
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} {
+ func.func @sum_1d_real(%arg0: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "a"}) -> f32 {
+ %c10 = arith.constant 10 : index
+ %0 = fir.alloca f32 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"}
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf32>>
+ %3 = fir.absent !fir.box<i1>
+ %c0 = arith.constant 0 : index
+ %4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref<!fir.char<1,13>>
+ %c5_i32 = arith.constant 5 : i32
+ %5 = fir.convert %2 : (!fir.box<!fir.array<10xf32>>) -> !fir.box<none>
+ %6 = fir.convert %4 : (!fir.ref<!fir.char<1,13>>) -> !fir.ref<i8>
+ %7 = fir.convert %c0 : (index) -> i32
+ %8 = fir.convert %3 : (!fir.box<i1>) -> !fir.box<none>
+ %9 = fir.call @_FortranASumReal4(%5, %6, %c5_i32, %7, %8) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f32
+ fir.store %9 to %0 : !fir.ref<f32>
+ %10 = fir.load %0 : !fir.ref<f32>
+ return %10 : f32
+ }
+ func.func private @_FortranASumReal4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f32 attributes {fir.runtime}
+ fir.global linkonce @_QQcl.2E2F6973756D5F352E66393000 constant : !fir.char<1,13> {
+ %0 = fir.string_lit "./isum_5.f90\00"(13) : !fir.char<1,13>
+ fir.has_value %0 : !fir.char<1,13>
+ }
+}
+
+
+// CHECK-LABEL: func.func @sum_1d_real(
+// CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "a"}) -> f32 {
+// CHECK: %[[CINDEX_10:.*]] = arith.constant 10 : index
+// CHECK: %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1>
+// CHECK: %[[A_BOX_F32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf32>>
+// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F32]] : (!fir.box<!fir.array<10xf32>>) -> !fir.box<none>
+// CHECK-NOT: fir.call @_FortranASumReal4({{.*}})
+// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
+// CHECK-NOT: fir.call @_FortranASumReal4({{.*}})
+// CHECK: return %{{.*}} : f32
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_FortranASumReal4_simplified(
+// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[ARR_BOX_F32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[ZERO]]) -> (f32) {
+// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F32]], %[[ITER]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f32>
+// CHECK: %[[NEW_SUM:.*]] = arith.addf %[[ITEM_VAL]], %[[SUM]] : f32
+// CHECK: fir.result %[[NEW_SUM]] : f32
+// CHECK: }
+// CHECK: return %[[RES]] : f32
+// CHECK: }
+
+// -----
+
// Call to SUM with 1D COMPLEX array is not replaced.
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} {
func.func @sum_1d_complex(%arg0: !fir.ref<!fir.array<10x!fir.complex<4>>> {fir.bindc_name = "a"}) -> !fir.complex<4> {
More information about the flang-commits
mailing list