[flang-commits] [flang] 54677d6 - [Flang] Add new ConvertComplexPow pass for Flang (#158642)
via flang-commits
flang-commits at lists.llvm.org
Thu Sep 18 17:51:14 PDT 2025
Author: Akash Banerjee
Date: 2025-09-19T01:51:10+01:00
New Revision: 54677d66c4af83351df63e513d7734e2c25160df
URL: https://github.com/llvm/llvm-project/commit/54677d66c4af83351df63e513d7734e2c25160df
DIFF: https://github.com/llvm/llvm-project/commit/54677d66c4af83351df63e513d7734e2c25160df.diff
LOG: [Flang] Add new ConvertComplexPow pass for Flang (#158642)
This PR introduces a new `ConvertComplexPow` pass for Flang that handles
complex power operations. With this change, complex exponentiation is
lowered to `complex.pow` operations unless `--math-runtime=precise` is used;
the `ConvertComplexPow` pass then converts these operations back into
library calls (the pass is skipped for AMDGCN targets, where `complex.pow`
is handled by the ROCDL library-call conversion).
- Adds a new `ConvertComplexPow` pass that converts complex.pow ops to
appropriate runtime library calls
- Updates complex power lowering to use `complex.pow` operations by
default instead of direct library calls
Related PR #158722 adds a new `complex.powi` op, enabling algebraic optimisations.
Added:
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
flang/test/Transforms/convert-complex-pow.fir
Modified:
flang/include/flang/Optimizer/Transforms/Passes.td
flang/include/flang/Tools/CrossToolHelpers.h
flang/lib/Frontend/FrontendActions.cpp
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/lib/Optimizer/Passes/Pipelines.cpp
flang/lib/Optimizer/Transforms/CMakeLists.txt
flang/test/Driver/bbc-mlir-pass-pipeline.f90
flang/test/Driver/mlir-debug-pass-pipeline.f90
flang/test/Driver/mlir-pass-pipeline.f90
flang/test/Fir/basic-program.fir
flang/test/Lower/HLFIR/binary-ops.f90
flang/test/Lower/Intrinsics/pow_complex16.f90
flang/test/Lower/Intrinsics/pow_complex16i.f90
flang/test/Lower/Intrinsics/pow_complex16k.f90
flang/test/Lower/power-operator.f90
flang/tools/bbc/bbc.cpp
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index b7fa0ca5f5719..88573fa9dff7d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -555,6 +555,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
"Prefer expanding without using Fortran runtime calls.">];
}
+// Rewrites `complex.pow` into runtime/libm calls. The default FIR optimizer
+// pipeline schedules it right after SimplifyRegionLite unless
+// MLIRToLLVMPassPipelineConfig::SkipConvertComplexPow is set; the flang driver
+// and bbc set that flag for AMDGCN targets.
+def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
+ let summary = "Convert complex.pow operations to library calls";
+ let description = [{
+ Replace `complex.pow` operations with calls to the appropriate
+ Fortran runtime or libm functions.
+ }];
+ let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
+ "mlir::complex::ComplexDialect",
+ "mlir::arith::ArithDialect"];
+}
+
def OptimizeArrayRepacking
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
let summary = "Optimizes redundant array repacking operations";
diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h
index 038f388f2ec0b..01c34eee014f3 100644
--- a/flang/include/flang/Tools/CrossToolHelpers.h
+++ b/flang/include/flang/Tools/CrossToolHelpers.h
@@ -135,6 +135,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments.
bool EnableOpenMP = false; ///< Enable OpenMP lowering.
bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode.
+ bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion.
std::string InstrumentFunctionEntry =
""; ///< Name of the instrument-function that is called on each
///< function-entry
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 3288908785276..c3c53d51015a2 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -738,6 +738,8 @@ void CodeGenAction::generateLLVMIR() {
pm.enableVerifier(/*verifyPasses=*/true);
MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
+ llvm::Triple pipelineTriple(invoc.getTargetOpts().triple);
+ config.SkipConvertComplexPow = pipelineTriple.isAMDGCN();
fir::registerDefaultInlinerPass(config);
if (auto vsr = getVScaleRange(ci)) {
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ce1376fd209cc..466458c05dba7 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
- bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
- if (!isAMDGPU)
+ if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
-
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- mlir::Value complexExp =
- builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
- mlir::Value result =
- builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ mlir::Value exp = args[1];
+ if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ exp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ }
+ mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
@@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
- genComplexMathOp<mlir::complex::PowOp>},
+ genComplexPow},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
- genComplexMathOp<mlir::complex::PowOp>},
+ genComplexPow},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
- genLibF128Call},
+ genComplexPow},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
@@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
- genLibF128Call},
+ genComplexPow},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
@@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
- genLibF128Call},
+ genComplexPow},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 58f60d43b1d49..fd7d521722a42 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -226,6 +226,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(mlir::createCanonicalizerPass(config));
pm.addPass(fir::createSimplifyRegionLite());
+ if (!pc.SkipConvertComplexPow)
+ pm.addPass(fir::createConvertComplexPow());
pm.addPass(mlir::createCSEPass());
if (pc.OptLevel.isOptimizingForSpeed())
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index a8812e08c1ccd..4ec16274830fe 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
GenRuntimeCallsForTest.cpp
SimplifyFIROperations.cpp
OptimizeArrayRepacking.cpp
+ ConvertComplexPow.cpp
DEPENDS
CUFAttrs
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
new file mode 100644
index 0000000000000..78f9d9e4f639a
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -0,0 +1,123 @@
+//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/static-multimap-view.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+/// Module pass that rewrites every `complex.pow` operation into a call to a
+/// Fortran runtime or libm routine; see ConvertComplexPowPass::runOnOperation.
+class ConvertComplexPowPass
+    : public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
+public:
+  // Dialects this pass depends on. The list mirrors the `dependentDialects`
+  // of the ConvertComplexPow definition in Passes.td.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
+                    arith::ArithDialect, func::FuncDialect>();
+  }
+  void runOnOperation() override;
+};
+} // namespace
+
+// Looks up the function `name` in the builder's module. If it is not there,
+// declares it with signature `type` and tags the new declaration with the FIR
+// symbol-name attribute and the FIR runtime attribute.
+static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
+                                 StringRef name, FunctionType type) {
+  func::FuncOp existing = builder.getNamedFunction(name);
+  if (existing)
+    return existing;
+
+  func::FuncOp decl = builder.createFunction(loc, name, type);
+  decl->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
+  decl->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
+                builder.getUnitAttr());
+  return decl;
+}
+
+// Returns true when `v` is produced by an `arith.constant` holding a
+// floating-point zero. Any other producer, or a block argument, yields false.
+static bool isZero(Value v) {
+  auto cst = v.getDefiningOp<arith::ConstantOp>();
+  if (!cst)
+    return false;
+  auto attr = dyn_cast<FloatAttr>(cst.getValue());
+  return attr && attr.getValue().isZero();
+}
+
+// Returns the Fortran runtime routine computing complex**integer for a complex
+// element type of `realBits` bits and an integer exponent of `intBits` bits.
+// Returns an empty StringRef when no dedicated routine exists.
+static StringRef getComplexIntPowName(unsigned realBits, unsigned intBits) {
+  if (intBits == 32) {
+    switch (realBits) {
+    case 32:
+      return RTNAME_STRING(cpowi);
+    case 64:
+      return RTNAME_STRING(zpowi);
+    case 128:
+      return RTNAME_STRING(cqpowi);
+    }
+  } else if (intBits == 64) {
+    switch (realBits) {
+    case 32:
+      return RTNAME_STRING(cpowk);
+    case 64:
+      return RTNAME_STRING(zpowk);
+    case 128:
+      return RTNAME_STRING(cqpowk);
+    }
+  }
+  return {};
+}
+
+// Returns the library routine computing complex**complex for a complex
+// element type of `realBits` bits, or an empty StringRef if none is known.
+static StringRef getComplexPowName(unsigned realBits) {
+  switch (realBits) {
+  case 32:
+    return "cpowf";
+  case 64:
+    return "cpow";
+  case 128:
+    return RTNAME_STRING(CPowF128);
+  }
+  return {};
+}
+
+// Lowering encodes `z ** n` (integer n) as
+// complex.pow(z, complex.create(fir.convert(n), 0.0)). If `exp` has that
+// shape and `n` is a signless integer, returns `n`; otherwise returns null.
+// Only signless integers are accepted because the runtime routines are
+// declared with a signless `iN` parameter.
+static Value getIntegerExponent(Value exp) {
+  auto create = exp.getDefiningOp<complex::CreateOp>();
+  if (!create || !isZero(create.getImaginary()))
+    return {};
+  auto conv = create.getReal().getDefiningOp<fir::ConvertOp>();
+  if (!conv || !conv.getValue().getType().isSignlessInteger())
+    return {};
+  return conv.getValue();
+}
+
+// Replaces each complex.pow in the module with a fir.call to a library
+// routine, copying the op's fastmath flags onto the call.
+//  - An integer exponent (see getIntegerExponent) selects the matching
+//    _FortranA{c,z,cq}pow{i,k} routine.
+//  - Otherwise, including integer widths without a dedicated routine, the
+//    generic complex**complex routine is called with the original exponent.
+//  - Element precisions without any known routine are left untouched.
+void ConvertComplexPowPass::runOnOperation() {
+  ModuleOp mod = getOperation();
+  fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
+
+  mod.walk([&](complex::PowOp op) {
+    builder.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto complexTy = cast<ComplexType>(op.getType());
+    unsigned realBits = cast<FloatType>(complexTy.getElementType()).getWidth();
+
+    Value base = op.getLhs();
+    Value rhs = op.getRhs();
+
+    func::FuncOp callee;
+    SmallVector<Value> args;
+    if (Value intExp = getIntegerExponent(rhs)) {
+      Type intTy = intExp.getType();
+      StringRef name =
+          getComplexIntPowName(realBits, cast<IntegerType>(intTy).getWidth());
+      if (!name.empty()) {
+        auto funcTy = builder.getFunctionType({complexTy, intTy}, {complexTy});
+        callee = getOrDeclare(builder, loc, name, funcTy);
+        args = {base, intExp};
+      }
+    }
+
+    // Fall back to the generic complex**complex routine. Previously an
+    // integer exponent of an unsupported width left the op unconverted.
+    if (!callee) {
+      StringRef name = getComplexPowName(realBits);
+      if (name.empty())
+        return; // No library routine for this precision; keep complex.pow.
+      auto funcTy =
+          builder.getFunctionType({complexTy, complexTy}, {complexTy});
+      callee = getOrDeclare(builder, loc, name, funcTy);
+      args = {base, rhs};
+    }
+
+    auto call = fir::CallOp::create(builder, loc, callee, args);
+    if (auto fmf = op.getFastmathAttr())
+      call.setFastmathAttr(fmf);
+    op.replaceAllUsesWith(call.getResult(0));
+    op.erase();
+  });
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index f3791fe9f8dc3..bf2712d547a82 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -69,6 +69,7 @@
! CHECK-NEXT: SCFToControlFlow
! CHECK-NEXT: Canonicalizer
! CHECK-NEXT: SimplifyRegionLite
+! CHECK-NEXT: ConvertComplexPow
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42a71b2d6adc3..5943a3c61c342 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -96,6 +96,7 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index e85a7728fc9af..4fd89d6f15d46 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -127,6 +127,7 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 0a31397efb332..195e5ad7f9dc8 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -125,6 +125,7 @@ func.func @_QQmain() {
// PASSES-NEXT: SCFToControlFlow
// PASSES-NEXT: Canonicalizer
// PASSES-NEXT: SimplifyRegionLite
+// PASSES-NEXT: ConvertComplexPow
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 72cd048ea3615..1fbd333db37c3 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -168,7 +168,7 @@ subroutine complex_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
-! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
+! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
subroutine real_to_int_power(x, y, z)
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! CHECK: %[[VAL_8:.*]] = complex.pow
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16.f90 b/flang/test/Lower/Intrinsics/pow_complex16.f90
index 7467986832479..c026dd242e964 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a, b
b = a ** b
end
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 6f8684d9a663a..1827863a57f43 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index d3765050640ae..039dfd5152a06 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..3058927144248 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -1,10 +1,10 @@
-! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST"
-! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST"
-! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST"
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s
! Test power operation lowering
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAcpowi
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAcpowi
end subroutine
! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAcpowk
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAcpowk
end subroutine
! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAzpowi
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAzpowi
end subroutine
! CHECK-LABEL: pow_c8_i8
@@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAzpowk
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAzpowk
end subroutine
! CHECK-LABEL: pow_c4_c4
subroutine pow_c4_c4(x, y, z)
complex :: x, y, z
z = x ** y
- ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
- ! PRECISE: call @cpowf
+ ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f32>
+ ! PRECISE: fir.call @cpowf
end subroutine
! CHECK-LABEL: pow_c8_c8
subroutine pow_c8_c8(x, y, z)
complex(8) :: x, y, z
z = x ** y
- ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
- ! PRECISE: call @cpow
+ ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
+ ! PRECISE: fir.call @cpow
end subroutine
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
new file mode 100644
index 0000000000000..e09fa7316c4b0
--- /dev/null
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -0,0 +1,111 @@
+// RUN: fir-opt --convert-complex-pow %s | FileCheck %s
+
+// Lowering encodes an integer power `z ** n` as
+// complex.pow(z, complex.create(fir.convert(n), 0.0)). The pow_c*_i*
+// functions below check that the pass recovers `n` and calls the matching
+// _FortranA{c,z,cq}pow{i,k} runtime routine. The remaining functions use an
+// exponent with a non-zero imaginary part. They check that the pass calls
+// cpowf / cpow / _FortranACPowF128 and carries over the op's fastmath flags.
+module {
+ func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = fir.convert %arg1 : (i32) -> f32
+ %1 = complex.create %0, %c0 : complex<f32>
+ %2 = complex.pow %arg0, %1 : complex<f32>
+ return %2 : complex<f32>
+ }
+
+ func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = fir.convert %arg1 : (i64) -> f32
+ %1 = complex.create %0, %c0 : complex<f32>
+ %2 = complex.pow %arg0, %1 : complex<f32>
+ return %2 : complex<f32>
+ }
+
+ func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
+ %c0 = arith.constant 0.0 : f64
+ %0 = fir.convert %arg1 : (i32) -> f64
+ %1 = complex.create %0, %c0 : complex<f64>
+ %2 = complex.pow %arg0, %1 : complex<f64>
+ return %2 : complex<f64>
+ }
+
+ func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
+ %c0 = arith.constant 0.0 : f64
+ %0 = fir.convert %arg1 : (i64) -> f64
+ %1 = complex.create %0, %c0 : complex<f64>
+ %2 = complex.pow %arg0, %1 : complex<f64>
+ return %2 : complex<f64>
+ }
+
+ func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
+ %c0 = arith.constant 0.0 : f128
+ %0 = fir.convert %arg1 : (i32) -> f128
+ %1 = complex.create %0, %c0 : complex<f128>
+ %2 = complex.pow %arg0, %1 : complex<f128>
+ return %2 : complex<f128>
+ }
+
+ func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
+ %c0 = arith.constant 0.0 : f128
+ %0 = fir.convert %arg1 : (i64) -> f128
+ %1 = complex.create %0, %c0 : complex<f128>
+ %2 = complex.pow %arg0, %1 : complex<f128>
+ return %2 : complex<f128>
+ }
+
+ // Exponents with a non-zero imaginary part: generic complex**complex calls.
+ func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
+ %c1 = arith.constant 1.0 : f32
+ %0 = complex.create %arg1, %c1 : complex<f32>
+ %1 = complex.pow %arg0, %0 fastmath<fast> : complex<f32>
+ return %1 : complex<f32>
+ }
+
+ func.func @pow_c8_complex(%arg0: complex<f64>, %arg1: f64) -> complex<f64> {
+ %c2 = arith.constant 2.0 : f64
+ %0 = complex.create %arg1, %c2 : complex<f64>
+ %1 = complex.pow %arg0, %0 : complex<f64>
+ return %1 : complex<f64>
+ }
+
+ func.func @pow_c16_complex(%arg0: complex<f128>, %arg1: f128) -> complex<f128> {
+ %c3 = arith.constant 3.0 : f128
+ %0 = complex.create %arg1, %c3 : complex<f128>
+ %1 = complex.pow %arg0, %0 : complex<f128>
+ return %1 : complex<f128>
+ }
+}
+
+// CHECK-LABEL: func.func @pow_c4_i4(
+// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c4_i8(
+// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_i4(
+// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_i8(
+// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_i4(
+// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_i8(
+// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c4_fast(
+// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
+// CHECK: fir.call @cpowf(%{{.*}}, %[[EXP]]) fastmath<fast> : (complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_complex(
+// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f64>
+// CHECK: fir.call @cpow(%{{.*}}, %[[EXP]]) : (complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_complex(
+// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
+// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
+// CHECK-NOT: complex.pow
\ No newline at end of file
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index 82dff2653ad09..69a45c66a079a 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -538,6 +538,7 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
// Add O2 optimizer pass pipeline.
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+ config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN();
if (enableOpenMP)
config.EnableOpenMP = true;
config.NSWOnLoopVarInc = !integerWrapAround;
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 0372f32d6b6df..72b1fa6e833f9 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,9 +64,12 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
LogicalResult matchAndRewrite(complex::PowOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs());
- Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase);
- Value exp = complex::ExpOp::create(rewriter, loc, mul);
+ auto fastmath = op.getFastmathAttr();
+ Value logBase =
+ complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath);
+ Value mul =
+ complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath);
+ Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath);
rewriter.replaceOp(op, exp);
return success();
}
More information about the flang-commits
mailing list