[flang-commits] [flang] [Flang] Add new ConvertComplexPow pass for Flang (PR #158642)
Akash Banerjee via flang-commits
flang-commits at lists.llvm.org
Mon Sep 15 10:20:23 PDT 2025
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158642
>From bcf4b5ada40dbb0d764eacb42047a31b16b6c89d Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Sat, 13 Sep 2025 01:39:27 +0100
Subject: [PATCH 1/2] Force lowering to complex.pow ops.
---
.../flang/Optimizer/Transforms/Passes.td | 11 ++
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 30 ++---
flang/lib/Optimizer/Passes/Pipelines.cpp | 1 +
flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 +
.../Transforms/ConvertComplexPow.cpp | 125 ++++++++++++++++++
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 2 +
.../test/Driver/mlir-debug-pass-pipeline.f90 | 2 +
flang/test/Driver/mlir-pass-pipeline.f90 | 2 +
flang/test/Fir/basic-program.fir | 2 +
flang/test/Lower/HLFIR/binary-ops.f90 | 4 +-
flang/test/Lower/Intrinsics/pow_complex16.f90 | 5 +-
.../test/Lower/Intrinsics/pow_complex16i.f90 | 5 +-
.../test/Lower/Intrinsics/pow_complex16k.f90 | 5 +-
flang/test/Lower/power-operator.f90 | 34 ++---
flang/test/Transforms/convert-complex-pow.fir | 102 ++++++++++++++
15 files changed, 293 insertions(+), 38 deletions(-)
create mode 100644 flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
create mode 100644 flang/test/Transforms/convert-complex-pow.fir
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e3001454cdf19..0ed4bb66aff0d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
"Prefer expanding without using Fortran runtime calls.">];
}
+def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> {
+ let summary = "Convert complex.pow operations to library calls";
+ let description = [{
+ Replace `complex.pow` operations with calls to the appropriate
+ Fortran runtime or libm functions.
+ }];
+ let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
+ "mlir::complex::ComplexDialect",
+ "mlir::arith::ArithDialect"];
+}
+
def OptimizeArrayRepacking
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
let summary = "Optimizes redundant array repacking operations";
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ce1376fd209cc..466458c05dba7 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
- bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
- if (!isAMDGPU)
+ if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
-
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- mlir::Value complexExp =
- builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
- mlir::Value result =
- builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ mlir::Value exp = args[1];
+ if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ exp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ }
+ mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
@@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
- genComplexMathOp<mlir::complex::PowOp>},
+ genComplexPow},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
- genComplexMathOp<mlir::complex::PowOp>},
+ genComplexPow},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
- genLibF128Call},
+ genComplexPow},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
@@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
- genLibF128Call},
+ genComplexPow},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
@@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
- genLibF128Call},
+ genComplexPow},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 7c2777baebef1..ddcfffc9f158f 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -225,6 +225,7 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(mlir::createCanonicalizerPass(config));
pm.addPass(fir::createSimplifyRegionLite());
+ pm.addPass(fir::createConvertComplexPow());
pm.addPass(mlir::createCSEPass());
if (pc.OptLevel.isOptimizingForSpeed())
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index a8812e08c1ccd..4ec16274830fe 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
GenRuntimeCallsForTest.cpp
SimplifyFIROperations.cpp
OptimizeArrayRepacking.cpp
+ ConvertComplexPow.cpp
DEPENDS
CUFAttrs
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
new file mode 100644
index 0000000000000..8b62237cf539d
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -0,0 +1,125 @@
+//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/static-multimap-view.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+class ConvertComplexPowPass
+ : public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
+public:
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
+ arith::ArithDialect, func::FuncDialect>();
+ }
+ void runOnOperation() override;
+};
+} // namespace
+
+// Helper to declare or get a math library function.
+static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
+ StringRef name, FunctionType type) {
+ if (auto func = builder.getNamedFunction(name))
+ return func;
+ auto func = builder.createFunction(loc, name, type);
+ func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
+ func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
+ builder.getUnitAttr());
+ return func;
+}
+
+static bool isZero(Value v) {
+ if (auto cst = v.getDefiningOp<arith::ConstantOp>())
+ if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
+ return attr.getValue().isZero();
+ return false;
+}
+
+void ConvertComplexPowPass::runOnOperation() {
+ auto func = getOperation();
+ auto mod = func->getParentOfType<ModuleOp>();
+ if (fir::getTargetTriple(mod).isAMDGCN())
+ return;
+
+ fir::FirOpBuilder builder(func, fir::getKindMapping(mod));
+
+ func.walk([&](complex::PowOp op) {
+ builder.setInsertionPoint(op);
+ Location loc = op.getLoc();
+ auto complexTy = cast<ComplexType>(op.getType());
+ auto elemTy = complexTy.getElementType();
+
+ Value base = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Value intExp;
+ if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
+ if (isZero(create.getImaginary())) {
+ if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
+ if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
+ intExp = conv.getValue();
+ }
+ }
+ }
+
+ func::FuncOp callee;
+ SmallVector<Value> args;
+ if (intExp) {
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+ auto funcTy = builder.getFunctionType(
+ {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+ if (realBits == 32 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+ else if (realBits == 32 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+ else if (realBits == 64 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+ else if (realBits == 64 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+ else if (realBits == 128 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+ else if (realBits == 128 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+ else
+ return;
+ args = {base, intExp};
+ } else {
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ auto funcTy =
+ builder.getFunctionType({complexTy, complexTy}, {complexTy});
+ if (realBits == 32)
+ callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+ else if (realBits == 64)
+ callee = getOrDeclare(builder, loc, "cpow", funcTy);
+ else if (realBits == 128)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+ else
+ return;
+ args = {base, rhs};
+ }
+
+ auto call = fir::CallOp::create(builder, loc, callee, args);
+ op.replaceAllUsesWith(call.getResult(0));
+ op.erase();
+ });
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index f3791fe9f8dc3..30cb97e4455ee 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -69,6 +69,8 @@
! CHECK-NEXT: SCFToControlFlow
! CHECK-NEXT: Canonicalizer
! CHECK-NEXT: SimplifyRegionLite
+! CHECK-NEXT: 'func.func' Pipeline
+! CHECK-NEXT: ConvertComplexPow
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42a71b2d6adc3..bb6d5509c3269 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -96,6 +96,8 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index e85a7728fc9af..6006f6672ee72 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -127,6 +127,8 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 0a31397efb332..a2e3cda8f2325 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -125,6 +125,8 @@ func.func @_QQmain() {
// PASSES-NEXT: SCFToControlFlow
// PASSES-NEXT: Canonicalizer
// PASSES-NEXT: SimplifyRegionLite
+// PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT: ConvertComplexPow
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 72cd048ea3615..1fbd333db37c3 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -168,7 +168,7 @@ subroutine complex_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
-! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
+! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
subroutine real_to_int_power(x, y, z)
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! CHECK: %[[VAL_8:.*]] = complex.pow
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16.f90 b/flang/test/Lower/Intrinsics/pow_complex16.f90
index 7467986832479..c026dd242e964 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a, b
b = a ** b
end
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 6f8684d9a663a..1827863a57f43 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index d3765050640ae..039dfd5152a06 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..3058927144248 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -1,10 +1,10 @@
-! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST"
-! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST"
-! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST"
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s
! Test power operation lowering
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAcpowi
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAcpowi
end subroutine
! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAcpowk
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAcpowk
end subroutine
! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAzpowi
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAzpowi
end subroutine
! CHECK-LABEL: pow_c8_i8
@@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAzpowk
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAzpowk
end subroutine
! CHECK-LABEL: pow_c4_c4
subroutine pow_c4_c4(x, y, z)
complex :: x, y, z
z = x ** y
- ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
- ! PRECISE: call @cpowf
+ ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f32>
+ ! PRECISE: fir.call @cpowf
end subroutine
! CHECK-LABEL: pow_c8_c8
subroutine pow_c8_c8(x, y, z)
complex(8) :: x, y, z
z = x ** y
- ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
- ! PRECISE: call @cpow
+ ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
+ ! PRECISE: fir.call @cpow
end subroutine
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
new file mode 100644
index 0000000000000..d980817aba9b9
--- /dev/null
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -0,0 +1,102 @@
+// RUN: fir-opt --convert-complex-pow %s | FileCheck %s
+
+module {
+ func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
+ %c0 = arith.constant 0.000000e+00 : f32
+ %c1 = fir.convert %arg1 : (i32) -> f32
+ %c2 = complex.create %c1, %c0 : complex<f32>
+ %0 = complex.pow %arg0, %c2 : complex<f32>
+ return %0 : complex<f32>
+ }
+
+ func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
+ %c0 = arith.constant 0.000000e+00 : f32
+ %c1 = fir.convert %arg1 : (i64) -> f32
+ %c2 = complex.create %c1, %c0 : complex<f32>
+ %0 = complex.pow %arg0, %c2 : complex<f32>
+ return %0 : complex<f32>
+ }
+
+ func.func @pow_c4_c4(%arg0: complex<f32>, %arg1: complex<f32>) -> complex<f32> {
+ %0 = complex.pow %arg0, %arg1 : complex<f32>
+ return %0 : complex<f32>
+ }
+
+ func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
+ %c0 = arith.constant 0.000000e+00 : f64
+ %c1 = fir.convert %arg1 : (i32) -> f64
+ %c2 = complex.create %c1, %c0 : complex<f64>
+ %0 = complex.pow %arg0, %c2 : complex<f64>
+ return %0 : complex<f64>
+ }
+
+ func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
+ %c0 = arith.constant 0.000000e+00 : f64
+ %c1 = fir.convert %arg1 : (i64) -> f64
+ %c2 = complex.create %c1, %c0 : complex<f64>
+ %0 = complex.pow %arg0, %c2 : complex<f64>
+ return %0 : complex<f64>
+ }
+
+ func.func @pow_c8_c8(%arg0: complex<f64>, %arg1: complex<f64>) -> complex<f64> {
+ %0 = complex.pow %arg0, %arg1 : complex<f64>
+ return %0 : complex<f64>
+ }
+
+ func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
+ %c0 = arith.constant 0.000000e+00 : f128
+ %c1 = fir.convert %arg1 : (i32) -> f128
+ %c2 = complex.create %c1, %c0 : complex<f128>
+ %0 = complex.pow %arg0, %c2 : complex<f128>
+ return %0 : complex<f128>
+ }
+
+ func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
+ %c0 = arith.constant 0.000000e+00 : f128
+ %c1 = fir.convert %arg1 : (i64) -> f128
+ %c2 = complex.create %c1, %c0 : complex<f128>
+ %0 = complex.pow %arg0, %c2 : complex<f128>
+ return %0 : complex<f128>
+ }
+
+ func.func @pow_c16_c16(%arg0: complex<f128>, %arg1: complex<f128>) -> complex<f128> {
+ %0 = complex.pow %arg0, %arg1 : complex<f128>
+ return %0 : complex<f128>
+ }
+}
+
+// CHECK-LABEL: func.func @pow_c4_i4(
+// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c4_i8(
+// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c4_c4(
+// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_i4(
+// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_i8(
+// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_c8(
+// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_i4(
+// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_i8(
+// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_c16(
+// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
+// CHECK-NOT: complex.pow
>From c8715a11d49e93240926a5b98e1f0b0e37b83f29 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Mon, 15 Sep 2025 18:19:50 +0100
Subject: [PATCH 2/2] Change ConvertComplexPow from func to module pass.
---
flang/include/flang/Optimizer/Transforms/Passes.td | 2 +-
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 +++----
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 3 +--
flang/test/Driver/mlir-debug-pass-pipeline.f90 | 3 +--
flang/test/Driver/mlir-pass-pipeline.f90 | 3 +--
flang/test/Fir/basic-program.fir | 3 +--
6 files changed, 8 insertions(+), 13 deletions(-)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 0ed4bb66aff0d..093d5de028048 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -551,7 +551,7 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
"Prefer expanding without using Fortran runtime calls.">];
}
-def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> {
+def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
let summary = "Convert complex.pow operations to library calls";
let description = [{
Replace `complex.pow` operations with calls to the appropriate
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index 8b62237cf539d..dced5f90d6924 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -55,14 +55,13 @@ static bool isZero(Value v) {
}
void ConvertComplexPowPass::runOnOperation() {
- auto func = getOperation();
- auto mod = func->getParentOfType<ModuleOp>();
+ ModuleOp mod = getOperation();
if (fir::getTargetTriple(mod).isAMDGCN())
return;
- fir::FirOpBuilder builder(func, fir::getKindMapping(mod));
+ fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
- func.walk([&](complex::PowOp op) {
+ mod.walk([&](complex::PowOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 30cb97e4455ee..bf2712d547a82 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -69,8 +69,7 @@
! CHECK-NEXT: SCFToControlFlow
! CHECK-NEXT: Canonicalizer
! CHECK-NEXT: SimplifyRegionLite
-! CHECK-NEXT: 'func.func' Pipeline
-! CHECK-NEXT: ConvertComplexPow
+! CHECK-NEXT: ConvertComplexPow
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index bb6d5509c3269..5943a3c61c342 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -96,8 +96,7 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
-! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT: ConvertComplexPow
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 6006f6672ee72..4fd89d6f15d46 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -127,8 +127,7 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
-! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT: ConvertComplexPow
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index a2e3cda8f2325..195e5ad7f9dc8 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -125,8 +125,7 @@ func.func @_QQmain() {
// PASSES-NEXT: SCFToControlFlow
// PASSES-NEXT: Canonicalizer
// PASSES-NEXT: SimplifyRegionLite
-// PASSES-NEXT: 'func.func' Pipeline
-// PASSES-NEXT: ConvertComplexPow
+// PASSES-NEXT: ConvertComplexPow
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
More information about the flang-commits
mailing list