[flang-commits] [flang] [Flang] Add new ConvertComplexPow pass for Flang (PR #158642)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Mon Sep 15 10:20:23 PDT 2025


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158642

>From bcf4b5ada40dbb0d764eacb42047a31b16b6c89d Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Sat, 13 Sep 2025 01:39:27 +0100
Subject: [PATCH 1/2] Force lowering to complex.pow ops.

---
 .../flang/Optimizer/Transforms/Passes.td      |  11 ++
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp |  30 ++---
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   1 +
 flang/lib/Optimizer/Transforms/CMakeLists.txt |   1 +
 .../Transforms/ConvertComplexPow.cpp          | 125 ++++++++++++++++++
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   2 +
 .../test/Driver/mlir-debug-pass-pipeline.f90  |   2 +
 flang/test/Driver/mlir-pass-pipeline.f90      |   2 +
 flang/test/Fir/basic-program.fir              |   2 +
 flang/test/Lower/HLFIR/binary-ops.f90         |   4 +-
 flang/test/Lower/Intrinsics/pow_complex16.f90 |   5 +-
 .../test/Lower/Intrinsics/pow_complex16i.f90  |   5 +-
 .../test/Lower/Intrinsics/pow_complex16k.f90  |   5 +-
 flang/test/Lower/power-operator.f90           |  34 ++---
 flang/test/Transforms/convert-complex-pow.fir | 102 ++++++++++++++
 15 files changed, 293 insertions(+), 38 deletions(-)
 create mode 100644 flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
 create mode 100644 flang/test/Transforms/convert-complex-pow.fir

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e3001454cdf19..0ed4bb66aff0d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
       "Prefer expanding without using Fortran runtime calls.">];
 }
 
+def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> {
+  let summary = "Convert complex.pow operations to library calls";
+  let description = [{
+    Replace `complex.pow` operations with calls to the appropriate
+    Fortran runtime or libm functions.
+  }];
+  let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
+                           "mlir::complex::ComplexDialect",
+                           "mlir::arith::ArithDialect"];
+}
+
 def OptimizeArrayRepacking
     : Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
   let summary = "Optimizes redundant array repacking operations";
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ce1376fd209cc..466458c05dba7 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
                           const MathOperation &mathOp,
                           mlir::FunctionType mathLibFuncType,
                           llvm::ArrayRef<mlir::Value> args) {
-  bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
-  if (!isAMDGPU)
+  if (mathRuntimeVersion == preciseVersion)
     return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
-
   auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
-  auto realTy = complexTy.getElementType();
-  mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
-  mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
-  mlir::Value complexExp =
-      builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
-  mlir::Value result =
-      builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+  mlir::Value exp = args[1];
+  if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+    auto realTy = complexTy.getElementType();
+    mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+    mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+    exp =
+        builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+  }
+  mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
   result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
   return result;
 }
@@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
     {"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
     {"pow", "cpowf",
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
-     genComplexMathOp<mlir::complex::PowOp>},
+     genComplexPow},
     {"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
-     genComplexMathOp<mlir::complex::PowOp>},
+     genComplexPow},
     {"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
-     genLibF128Call},
+     genComplexPow},
     {"pow", RTNAME_STRING(FPow4i),
      genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
      genMathOp<mlir::math::FPowIOp>},
@@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = {
      genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
      genComplexPow},
     {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
-     genLibF128Call},
+     genComplexPow},
     {"pow", RTNAME_STRING(cpowk),
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
      genComplexPow},
@@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = {
      genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
      genComplexPow},
     {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
-     genLibF128Call},
+     genComplexPow},
     {"pow-unsigned", RTNAME_STRING(UPow1),
      genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
     {"pow-unsigned", RTNAME_STRING(UPow2),
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 7c2777baebef1..ddcfffc9f158f 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -225,6 +225,7 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
 
   pm.addPass(mlir::createCanonicalizerPass(config));
   pm.addPass(fir::createSimplifyRegionLite());
+  pm.addPass(fir::createConvertComplexPow());
   pm.addPass(mlir::createCSEPass());
 
   if (pc.OptLevel.isOptimizingForSpeed())
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index a8812e08c1ccd..4ec16274830fe 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
   GenRuntimeCallsForTest.cpp
   SimplifyFIROperations.cpp
   OptimizeArrayRepacking.cpp
+  ConvertComplexPow.cpp
 
   DEPENDS
   CUFAttrs
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
new file mode 100644
index 0000000000000..8b62237cf539d
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -0,0 +1,125 @@
+//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/static-multimap-view.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+/// Replaces complex.pow operations with fir.call to libm/runtime routines.
+/// Dependent dialects come from the `dependentDialects` list of the pass
+/// definition in Passes.td, which the generated base class registers, so this
+/// class does not override getDependentDialects itself.
+class ConvertComplexPowPass
+    : public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+// Return the function named `name`, declaring it with `type` if it is absent.
+static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
+                                 StringRef name, FunctionType type) {
+  if (func::FuncOp existing = builder.getNamedFunction(name))
+    return existing;
+  func::FuncOp decl = builder.createFunction(loc, name, type);
+  decl->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
+  decl->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
+                builder.getUnitAttr());
+  return decl;
+}
+
+// True if `v` is produced by an arith.constant holding a floating-point zero.
+static bool isZero(Value v) {
+  auto cst = v.getDefiningOp<arith::ConstantOp>();
+  auto fp = cst ? dyn_cast<FloatAttr>(cst.getValue()) : FloatAttr();
+  return fp && fp.getValue().isZero();
+}
+
+void ConvertComplexPowPass::runOnOperation() {
+  auto func = getOperation();
+  auto mod = func->getParentOfType<ModuleOp>();
+  if (fir::getTargetTriple(mod).isAMDGCN())
+    return;
+
+  fir::FirOpBuilder builder(func, fir::getKindMapping(mod));
+
+  func.walk([&](complex::PowOp op) {
+    builder.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto complexTy = cast<ComplexType>(op.getType());
+    unsigned realBits = cast<FloatType>(complexTy.getElementType()).getWidth();
+
+    Value base = op.getLhs();
+    Value rhs = op.getRhs();
+
+    // Integer exponents reach here as complex.create(fir.convert(%i), 0.0);
+    // forward %i if it is signless, matching the declarations built below.
+    Value intExp;
+    auto create = rhs.getDefiningOp<complex::CreateOp>();
+    auto conv = create && isZero(create.getImaginary())
+                    ? create.getReal().getDefiningOp<fir::ConvertOp>()
+                    : fir::ConvertOp();
+    if (conv && conv.getValue().getType().isSignlessInteger())
+      intExp = conv.getValue();
+
+    func::FuncOp callee;
+    SmallVector<Value> args;
+    if (intExp) {
+      unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+      auto funcTy = builder.getFunctionType(
+          {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+      if (realBits == 32 && intBits == 32)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+      else if (realBits == 32 && intBits == 64)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+      else if (realBits == 64 && intBits == 32)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+      else if (realBits == 64 && intBits == 64)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+      else if (realBits == 128 && intBits == 32)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+      else if (realBits == 128 && intBits == 64)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+      else
+        return;
+      args = {base, intExp};
+    } else {
+      auto funcTy =
+          builder.getFunctionType({complexTy, complexTy}, {complexTy});
+      if (realBits == 32)
+        callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+      else if (realBits == 64)
+        callee = getOrDeclare(builder, loc, "cpow", funcTy);
+      else if (realBits == 128)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+      else
+        return;
+      args = {base, rhs};
+    }
+
+    auto call = fir::CallOp::create(builder, loc, callee, args);
+    // Carry the fastmath flags of the replaced complex.pow over to the call.
+    call.setFastmathAttr(op.getFastmathAttr());
+    op.replaceAllUsesWith(call.getResult(0));
+    op.erase();
+  });
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index f3791fe9f8dc3..30cb97e4455ee 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -69,6 +69,8 @@
 ! CHECK-NEXT: SCFToControlFlow
 ! CHECK-NEXT: Canonicalizer
 ! CHECK-NEXT: SimplifyRegionLite
+! CHECK-NEXT: 'func.func' Pipeline
+! CHECK-NEXT:   ConvertComplexPow
 ! CHECK-NEXT: CSE
 ! CHECK-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42a71b2d6adc3..bb6d5509c3269 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -96,6 +96,8 @@
 ! ALL-NEXT: SCFToControlFlow
 ! ALL-NEXT: Canonicalizer
 ! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT:   ConvertComplexPow
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index e85a7728fc9af..6006f6672ee72 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -127,6 +127,8 @@
 ! ALL-NEXT: SCFToControlFlow
 ! ALL-NEXT: Canonicalizer
 ! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT:   ConvertComplexPow
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 0a31397efb332..a2e3cda8f2325 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -125,6 +125,8 @@ func.func @_QQmain() {
 // PASSES-NEXT: SCFToControlFlow
 // PASSES-NEXT: Canonicalizer
 // PASSES-NEXT: SimplifyRegionLite
+// PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT:   ConvertComplexPow
 // PASSES-NEXT: CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 72cd048ea3615..1fbd333db37c3 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -168,7 +168,7 @@ subroutine complex_power(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
-! CHECK:  %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
+! CHECK:  %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
 
 
 subroutine real_to_int_power(x, y, z)
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK:  %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! CHECK:  %[[VAL_8:.*]] = complex.pow
 
 subroutine extremum(c, n, l)
   integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16.f90 b/flang/test/Lower/Intrinsics/pow_complex16.f90
index 7467986832479..c026dd242e964 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16.f90
@@ -1,9 +1,10 @@
 ! REQUIRES: flang-supports-f128-math
 ! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
-! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
   complex(16) :: a, b
   b = a ** b
 end
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 6f8684d9a663a..1827863a57f43 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -1,9 +1,10 @@
 ! REQUIRES: flang-supports-f128-math
 ! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
-! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
   complex(16) :: a
   integer(4) :: b
   b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index d3765050640ae..039dfd5152a06 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -1,9 +1,10 @@
 ! REQUIRES: flang-supports-f128-math
 ! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
-! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
   complex(16) :: a
   integer(8) :: b
   b = a ** b
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..3058927144248 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -1,10 +1,10 @@
-! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST"
-! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST"
-! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST"
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s
 
 ! Test power operation lowering
 
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
   complex :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: call @_FortranAcpowi
+  ! CHECK: complex.pow
+  ! PRECISE: fir.call @_FortranAcpowi
 end subroutine
 
 ! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
   complex :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: call @_FortranAcpowk
+  ! CHECK: complex.pow
+  ! PRECISE: fir.call @_FortranAcpowk
 end subroutine
 
 ! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
   complex(8) :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: call @_FortranAzpowi
+  ! CHECK: complex.pow
+  ! PRECISE: fir.call @_FortranAzpowi
 end subroutine
 
 ! CHECK-LABEL: pow_c8_i8
@@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z)
   complex(8) :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: call @_FortranAzpowk
+  ! CHECK: complex.pow
+  ! PRECISE: fir.call @_FortranAzpowk
 end subroutine
 
 ! CHECK-LABEL: pow_c4_c4
 subroutine pow_c4_c4(x, y, z)
   complex :: x, y, z
   z = x ** y
-  ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
-  ! PRECISE: call @cpowf
+  ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f32>
+  ! PRECISE: fir.call @cpowf
 end subroutine
 
 ! CHECK-LABEL: pow_c8_c8
 subroutine pow_c8_c8(x, y, z)
   complex(8) :: x, y, z
   z = x ** y
-  ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
-  ! PRECISE: call @cpow
+  ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
+  ! PRECISE: fir.call @cpow
 end subroutine
 
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
new file mode 100644
index 0000000000000..d980817aba9b9
--- /dev/null
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -0,0 +1,102 @@
+// RUN: fir-opt --convert-complex-pow %s | FileCheck %s
+
+module {
+  func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
+    %c0 = arith.constant 0.000000e+00 : f32
+    %c1 = fir.convert %arg1 : (i32) -> f32
+    %c2 = complex.create %c1, %c0 : complex<f32>
+    %0 = complex.pow %arg0, %c2 : complex<f32>
+    return %0 : complex<f32>
+  }
+
+  func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
+    %c0 = arith.constant 0.000000e+00 : f32
+    %c1 = fir.convert %arg1 : (i64) -> f32
+    %c2 = complex.create %c1, %c0 : complex<f32>
+    %0 = complex.pow %arg0, %c2 : complex<f32>
+    return %0 : complex<f32>
+  }
+
+  func.func @pow_c4_c4(%arg0: complex<f32>, %arg1: complex<f32>) -> complex<f32> {
+    %0 = complex.pow %arg0, %arg1 : complex<f32>
+    return %0 : complex<f32>
+  }
+
+  func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
+    %c0 = arith.constant 0.000000e+00 : f64
+    %c1 = fir.convert %arg1 : (i32) -> f64
+    %c2 = complex.create %c1, %c0 : complex<f64>
+    %0 = complex.pow %arg0, %c2 : complex<f64>
+    return %0 : complex<f64>
+  }
+
+  func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
+    %c0 = arith.constant 0.000000e+00 : f64
+    %c1 = fir.convert %arg1 : (i64) -> f64
+    %c2 = complex.create %c1, %c0 : complex<f64>
+    %0 = complex.pow %arg0, %c2 : complex<f64>
+    return %0 : complex<f64>
+  }
+
+  func.func @pow_c8_c8(%arg0: complex<f64>, %arg1: complex<f64>) -> complex<f64> {
+    %0 = complex.pow %arg0, %arg1 : complex<f64>
+    return %0 : complex<f64>
+  }
+
+  func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
+    %c0 = arith.constant 0.000000e+00 : f128
+    %c1 = fir.convert %arg1 : (i32) -> f128
+    %c2 = complex.create %c1, %c0 : complex<f128>
+    %0 = complex.pow %arg0, %c2 : complex<f128>
+    return %0 : complex<f128>
+  }
+
+  func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
+    %c0 = arith.constant 0.000000e+00 : f128
+    %c1 = fir.convert %arg1 : (i64) -> f128
+    %c2 = complex.create %c1, %c0 : complex<f128>
+    %0 = complex.pow %arg0, %c2 : complex<f128>
+    return %0 : complex<f128>
+  }
+
+  func.func @pow_c16_c16(%arg0: complex<f128>, %arg1: complex<f128>) -> complex<f128> {
+    %0 = complex.pow %arg0, %arg1 : complex<f128>
+    return %0 : complex<f128>
+  }
+}
+
+// CHECK-LABEL: func.func @pow_c4_i4(
+// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c4_i8(
+// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c4_c4(
+// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_i4(
+// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_i8(
+// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c8_c8(
+// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_i4(
+// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_i8(
+// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
+// CHECK-NOT: complex.pow
+
+// CHECK-LABEL: func.func @pow_c16_c16(
+// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
+// CHECK-NOT: complex.pow

>From c8715a11d49e93240926a5b98e1f0b0e37b83f29 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Mon, 15 Sep 2025 18:19:50 +0100
Subject: [PATCH 2/2] Change ConvertComplexPow from func to module pass.

---
 flang/include/flang/Optimizer/Transforms/Passes.td   | 2 +-
 flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 +++----
 flang/test/Driver/bbc-mlir-pass-pipeline.f90         | 3 +--
 flang/test/Driver/mlir-debug-pass-pipeline.f90       | 3 +--
 flang/test/Driver/mlir-pass-pipeline.f90             | 3 +--
 flang/test/Fir/basic-program.fir                     | 3 +--
 6 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 0ed4bb66aff0d..093d5de028048 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -551,7 +551,7 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
       "Prefer expanding without using Fortran runtime calls.">];
 }
 
-def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> {
+def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
   let summary = "Convert complex.pow operations to library calls";
   let description = [{
     Replace `complex.pow` operations with calls to the appropriate
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index 8b62237cf539d..dced5f90d6924 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -55,14 +55,13 @@ static bool isZero(Value v) {
 }
 
 void ConvertComplexPowPass::runOnOperation() {
-  auto func = getOperation();
-  auto mod = func->getParentOfType<ModuleOp>();
+  ModuleOp mod = getOperation();
   if (fir::getTargetTriple(mod).isAMDGCN())
     return;
 
-  fir::FirOpBuilder builder(func, fir::getKindMapping(mod));
+  fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
 
-  func.walk([&](complex::PowOp op) {
+  mod.walk([&](complex::PowOp op) {
     builder.setInsertionPoint(op);
     Location loc = op.getLoc();
     auto complexTy = cast<ComplexType>(op.getType());
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 30cb97e4455ee..bf2712d547a82 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -69,8 +69,7 @@
 ! CHECK-NEXT: SCFToControlFlow
 ! CHECK-NEXT: Canonicalizer
 ! CHECK-NEXT: SimplifyRegionLite
-! CHECK-NEXT: 'func.func' Pipeline
-! CHECK-NEXT:   ConvertComplexPow
+! CHECK-NEXT: ConvertComplexPow
 ! CHECK-NEXT: CSE
 ! CHECK-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index bb6d5509c3269..5943a3c61c342 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -96,8 +96,7 @@
 ! ALL-NEXT: SCFToControlFlow
 ! ALL-NEXT: Canonicalizer
 ! ALL-NEXT: SimplifyRegionLite
-! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT:   ConvertComplexPow
+! ALL-NEXT: ConvertComplexPow
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 6006f6672ee72..4fd89d6f15d46 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -127,8 +127,7 @@
 ! ALL-NEXT: SCFToControlFlow
 ! ALL-NEXT: Canonicalizer
 ! ALL-NEXT: SimplifyRegionLite
-! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT:   ConvertComplexPow
+! ALL-NEXT: ConvertComplexPow
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index a2e3cda8f2325..195e5ad7f9dc8 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -125,8 +125,7 @@ func.func @_QQmain() {
 // PASSES-NEXT: SCFToControlFlow
 // PASSES-NEXT: Canonicalizer
 // PASSES-NEXT: SimplifyRegionLite
-// PASSES-NEXT: 'func.func' Pipeline
-// PASSES-NEXT:   ConvertComplexPow
+// PASSES-NEXT: ConvertComplexPow
 // PASSES-NEXT: CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd



More information about the flang-commits mailing list