[flang-commits] [flang] ffe1661 - [flang] Propagate fastmath flags during intrinsics simplification.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Thu Nov 17 10:17:00 PST 2022


Author: Slava Zakharin
Date: 2022-11-17T10:16:47-08:00
New Revision: ffe1661fabc9cf379a10a0bf15268c6549e4836f

URL: https://github.com/llvm/llvm-project/commit/ffe1661fabc9cf379a10a0bf15268c6549e4836f
DIFF: https://github.com/llvm/llvm-project/commit/ffe1661fabc9cf379a10a0bf15268c6549e4836f.diff

LOG: [flang] Propagate fastmath flags during intrinsics simplification.

In general, when a call is inlined, the fastmath flags attached to the
call operation must be ignored. For user functions, this means that the
fastmath flags used for the function definition override any call site's
fastmath flags. For intrinsic functions, we can use the call site's
fastmath flags, but we have to make sure that call sites with different
flags produce/use different simplified versions of the same intrinsic
function.

Differential Revision: https://reviews.llvm.org/D138048

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Builder/FIRBuilder.h
    flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
    flang/test/Transforms/simplifyintrinsics.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index f6b795515ecc2..560b991ab53e6 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -419,6 +419,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// config.
   void setFastMathFlags(Fortran::common::MathOptionsBase options);
 
+  /// Get current FastMathFlags value.
+  mlir::arith::FastMathFlags getFastMathFlags() const { return fastMathFlags; }
+
   /// Dump the current function. (debug)
   LLVM_DUMP_METHOD void dumpFunc();
 

diff  --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index cdc3ab9393f3b..f74fe35bc3af0 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -85,6 +85,35 @@ class SimplifyIntrinsicsPass
 
 } // namespace
 
+/// Create a FirOpBuilder with \p op as the insertion point and with
+/// \p kindMap, additionally inheriting FastMathFlags from \p op (if any).
+static fir::FirOpBuilder
+getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
+  fir::FirOpBuilder builder{op, kindMap};
+  auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
+  if (!fmi)
+    return builder;
+
+  // Regardless of what default FastMathFlags are used by FirOpBuilder,
+  // override them with FastMathFlags attached to the operation.
+  builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
+  return builder;
+}
+
+/// Stringify FastMathFlags set for the given \p builder in a way
+/// that the string may be used for mangling a function name.
+/// If FastMathFlags are set to 'none', then the result is an empty
+/// string.
+static std::string getFastMathFlagsString(const fir::FirOpBuilder &builder) {
+  mlir::arith::FastMathFlags flags = builder.getFastMathFlags();
+  if (flags == mlir::arith::FastMathFlags::none)
+    return {};
+
+  std::string fmfString{mlir::arith::stringifyFastMathFlags(flags)};
+  std::replace(fmfString.begin(), fmfString.end(), ',', '_');
+  return fmfString;
+}
+
 /// Generate function type for the simplified version of RTNAME(Sum) and
 /// similar functions with a fir.box<none> type returning \p elementType.
 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
@@ -511,7 +540,8 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
   unsigned rank = getDimCount(args[0]);
   if (dimAndMaskAbsent && rank > 0) {
     mlir::Location loc = call.getLoc();
-    fir::FirOpBuilder builder(call, kindMap);
+    fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
+    std::string fmfString{getFastMathFlagsString(builder)};
 
     // Support only floating point and integer results now.
     mlir::Type resultType = call.getResult(0).getType();
@@ -535,7 +565,10 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
     // Mangle the function name with the rank value as "x<rank>".
     std::string funcName =
         (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
-         mlir::Twine{rank})
+         mlir::Twine{rank} +
+         // Mangle the generated function name with the FastMathFlags value so
+         // that call sites with different flags use different functions.
+         (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
             .str();
     mlir::func::FuncOp newFunc =
         getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
@@ -576,7 +609,10 @@ void SimplifyIntrinsicsPass::runOnOperation() {
           const mlir::Value &v1 = args[0];
           const mlir::Value &v2 = args[1];
           mlir::Location loc = call.getLoc();
-          fir::FirOpBuilder builder(op, kindMap);
+          fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
+          // Stringify the builder's FastMathFlags for mangling
+          // the generated function name.
+          std::string fmfString{getFastMathFlagsString(builder)};
 
           mlir::Type type = call.getResult(0).getType();
           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
@@ -611,9 +647,13 @@ void SimplifyIntrinsicsPass::runOnOperation() {
           // of the arguments.
           std::string typedFuncName(funcName);
           llvm::raw_string_ostream nameOS(typedFuncName);
-          nameOS << "_";
+          // Mangle the generated function name with the FastMathFlags value
+          // so that call sites with different flags use different functions.
+          if (!fmfString.empty())
+            nameOS << '_' << fmfString;
+          nameOS << '_';
           arg1Type->print(nameOS);
-          nameOS << "_";
+          nameOS << '_';
           arg2Type->print(nameOS);
 
           mlir::func::FuncOp newFunc = getOrCreateFunction(

diff  --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index e3ac9c930d299..dbd23520ef95a 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -998,3 +998,103 @@ fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
 // CHECK-NOT: call{{.*}}_FortranASumInteger8(
 // CHECK: call @_FortranASumInteger8x2_simplified(
 // CHECK-NOT: call{{.*}}_FortranASumInteger8(
+
+// -----
+
+func.func @dot_f32_contract_reassoc(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f32 {
+  %0 = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
+  %1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref<!fir.char<1,10>>
+  %c3_i32 = arith.constant 3 : i32
+  %2 = fir.convert %arg0 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  %3 = fir.convert %arg1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  %4 = fir.convert %1 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
+  %5 = fir.call @_FortranADotProductReal4(%2, %3, %4, %c3_i32) fastmath<contract,reassoc> : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32
+  fir.store %5 to %0 : !fir.ref<f32>
+  %6 = fir.load %0 : !fir.ref<f32>
+  return %6 : f32
+}
+
+func.func @dot_f32_fast(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f32 {
+  %0 = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
+  %1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref<!fir.char<1,10>>
+  %c3_i32 = arith.constant 3 : i32
+  %2 = fir.convert %arg0 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  %3 = fir.convert %arg1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  %4 = fir.convert %1 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
+  %5 = fir.call @_FortranADotProductReal4(%2, %3, %4, %c3_i32) fastmath<fast> : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32
+  fir.store %5 to %0 : !fir.ref<f32>
+  %6 = fir.load %0 : !fir.ref<f32>
+  return %6 : f32
+}
+
+func.func private @_FortranADotProductReal4(!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
+  %0 = fir.string_lit "./dot.f90\00"(10) : !fir.char<1,10>
+  fir.has_value %0 : !fir.char<1,10>
+}
+
+// CHECK-LABEL: @dot_f32_contract_reassoc
+// CHECK: fir.call @_FortranADotProductReal4_reassoc_contract_f32_f32_simplified(%2, %3) fastmath<reassoc,contract>
+// CHECK-LABEL: @dot_f32_fast
+// CHECK: fir.call @_FortranADotProductReal4_fast_f32_f32_simplified(%2, %3) fastmath<fast>
+// CHECK-LABEL: func.func private @_FortranADotProductReal4_reassoc_contract_f32_f32_simplified
+// CHECK: arith.mulf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f32
+// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f32
+// CHECK-LABEL: func.func private @_FortranADotProductReal4_fast_f32_f32_simplified
+// CHECK: arith.mulf %{{.*}}, %{{.*}} fastmath<fast> : f32
+// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<fast> : f32
+
+// -----
+
+func.func @sum_1d_real_contract_reassoc(%arg0: !fir.ref<!fir.array<10xf64>> {fir.bindc_name = "a"}) -> f64 {
+  %c10 = arith.constant 10 : index
+  %0 = fir.alloca f64 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"}
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
+  %3 = fir.absent !fir.box<i1>
+  %c0 = arith.constant 0 : index
+  %4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref<!fir.char<1,13>>
+  %c5_i32 = arith.constant 5 : i32
+  %5 = fir.convert %2 : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
+  %6 = fir.convert %4 : (!fir.ref<!fir.char<1,13>>) -> !fir.ref<i8>
+  %7 = fir.convert %c0 : (index) -> i32
+  %8 = fir.convert %3 : (!fir.box<i1>) -> !fir.box<none>
+  %9 = fir.call @_FortranASumReal8(%5, %6, %c5_i32, %7, %8) fastmath<contract,reassoc> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f64
+  fir.store %9 to %0 : !fir.ref<f64>
+  %10 = fir.load %0 : !fir.ref<f64>
+  return %10 : f64
+}
+
+func.func @sum_1d_real_fast(%arg0: !fir.ref<!fir.array<10xf64>> {fir.bindc_name = "a"}) -> f64 {
+  %c10 = arith.constant 10 : index
+  %0 = fir.alloca f64 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"}
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = fir.embox %arg0(%1) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
+  %3 = fir.absent !fir.box<i1>
+  %c0 = arith.constant 0 : index
+  %4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref<!fir.char<1,13>>
+  %c5_i32 = arith.constant 5 : i32
+  %5 = fir.convert %2 : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
+  %6 = fir.convert %4 : (!fir.ref<!fir.char<1,13>>) -> !fir.ref<i8>
+  %7 = fir.convert %c0 : (index) -> i32
+  %8 = fir.convert %3 : (!fir.box<i1>) -> !fir.box<none>
+  %9 = fir.call @_FortranASumReal8(%5, %6, %c5_i32, %7, %8) fastmath<fast> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f64
+  fir.store %9 to %0 : !fir.ref<f64>
+  %10 = fir.load %0 : !fir.ref<f64>
+  return %10 : f64
+}
+
+func.func private @_FortranASumReal8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> f64 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F6973756D5F352E66393000 constant : !fir.char<1,13> {
+  %0 = fir.string_lit "./isum_5.f90\00"(13) : !fir.char<1,13>
+  fir.has_value %0 : !fir.char<1,13>
+}
+
+// CHECK-LABEL: @sum_1d_real_contract_reassoc
+// CHECK: fir.call @_FortranASumReal8x1_reassoc_contract_simplified(%5) fastmath<reassoc,contract>
+// CHECK-LABEL: @sum_1d_real_fast
+// CHECK: fir.call @_FortranASumReal8x1_fast_simplified(%5) fastmath<fast>
+// CHECK-LABEL: func.func private @_FortranASumReal8x1_reassoc_contract_simplified
+// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<reassoc,contract> : f64
+// CHECK-LABEL: func.func private @_FortranASumReal8x1_fast_simplified
+// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath<fast> : f64


        


More information about the flang-commits mailing list