[flang-commits] [flang] [mlir] [MLIR] Add cpow support in ComplexToROCDLLibraryCalls (PR #153183)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Thu Aug 14 05:13:18 PDT 2025


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/153183

>From 1be3c7422c2539af0e3b796becbf83cc2e28dbf7 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 12 Aug 2025 14:06:38 +0100
Subject: [PATCH 1/3] [MLIR] Add cpow support in ComplexToROCDLLibraryCalls

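This PR contributes the following changes:
	1. Lower complex ** integer to complex.pow ops for the amdgcn-amd-amdhsa target (and when approximate fast-math is enabled) instead of calling the cpowi/zpowi/cpowk/zpowk runtime functions.
	2. Convert complex.pow(z, w) -> complex.exp(w * complex.log(z)).
	3. Convert x ** n for a constant exponent 2 <= n <= 8 into repeated multiplication: x ** 2 -> x * x, x ** 3 -> x * x * x, ..., x ** 8 -> x * x * x * x * x * x * x * x.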
---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 34 +++++++++++++++--
 flang/test/Lower/amdgcn-complex.f90           | 22 +++++++----
 flang/test/Lower/power-operator.f90           | 12 ++++--
 .../ComplexToROCDLLibraryCalls.cpp            | 38 ++++++++++++++++++-
 .../complex-to-rocdl-library-calls.mlir       | 27 +++++++++++++
 5 files changed, 115 insertions(+), 18 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 3e6fbafe8a6b3..2f8965adfb320 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,6 +1276,28 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
   return result;
 }
 
+mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc,
+                           const MathOperation &mathOp,
+                           mlir::FunctionType mathLibFuncType,
+                           llvm::ArrayRef<mlir::Value> args) {
+  bool canUseApprox = mlir::arith::bitEnumContainsAny(
+      builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
+  bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
+  if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
+    return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
+
+  auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
+  auto realTy = complexTy.getElementType();
+  mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
+  mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+  mlir::Value complexExp =
+      builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+  mlir::Value result =
+      builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+  result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+  return result;
+}
+
 /// Mapping between mathematical intrinsic operations and MLIR operations
 /// of some appropriate dialect (math, complex, etc.) or libm calls.
 /// TODO: support remaining Fortran math intrinsics.
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = {
      genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
      genMathOp<mlir::math::FPowIOp>},
     {"pow", RTNAME_STRING(cpowi),
-     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
+     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
+     genComplexPowI},
     {"pow", RTNAME_STRING(zpowi),
-     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
+     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
+     genComplexPowI},
     {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
      genLibF128Call},
     {"pow", RTNAME_STRING(cpowk),
-     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
+     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
+     genComplexPowI},
     {"pow", RTNAME_STRING(zpowk),
-     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
+     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
+     genComplexPowI},
     {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
      genLibF128Call},
     {"remainder", "remainderf",
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index f15c7db2b7316..3d52355d3d50a 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -1,21 +1,27 @@
 ! REQUIRES: amdgpu-registered-target
-! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
 
+! CHECK-LABEL: func @_QPcabsf_test(
+! CHECK: complex.abs
+! CHECK-NOT: fir.call @cabsf
 subroutine cabsf_test(a, b)
    complex :: a
    real :: b
    b = abs(a)
 end subroutine
 
-! CHECK-LABEL: func @_QPcabsf_test(
-! CHECK: complex.abs
-! CHECK-NOT: fir.call @cabsf
-
+! CHECK-LABEL: func @_QPcexpf_test(
+! CHECK: complex.exp
+! CHECK-NOT: fir.call @cexpf
 subroutine cexpf_test(a, b)
    complex :: a, b
    b = exp(a)
 end subroutine
 
-! CHECK-LABEL: func @_QPcexpf_test(
-! CHECK: complex.exp
-! CHECK-NOT: fir.call @cexpf
+! CHECK-LABEL: func @_QPpow_test(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test(a, b)
+   complex :: a, b
+   a = b**2
+end subroutine pow_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..2a0a09e090dde 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
   complex :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: call @_FortranAcpowi
+  ! PRECISE: call @_FortranAcpowi
+  ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
 end subroutine
 
 ! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
   complex :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: call @_FortranAcpowk
+  ! PRECISE: call @_FortranAcpowk
+  ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
 end subroutine
 
 ! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
   complex(8) :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: call @_FortranAzpowi
+  ! PRECISE: call @_FortranAzpowi
+  ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
 end subroutine
 
 ! CHECK-LABEL: pow_c8_i8
@@ -120,7 +123,8 @@ subroutine pow_c8_i8(x, y, z)
   complex(8) :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: call @_FortranAzpowk
+  ! PRECISE: call @_FortranAzpowk
+  ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
 end subroutine
 
 ! CHECK-LABEL: pow_c4_c4
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index b3d6d59e25bd0..558fcdf782800 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,10 +56,43 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
 private:
   std::string funcName;
 };
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+  using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::PowOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    if (auto constOp = op.getRhs().getDefiningOp<complex::ConstantOp>()) {
+      ArrayAttr value = constOp.getValue();
+      if (value.size() == 2) {
+        auto real = dyn_cast<FloatAttr>(value[0]);
+        auto imag = dyn_cast<FloatAttr>(value[1]);
+        if (real && imag && imag.getValue().isZero())
+          for (int i = 2; i <= 8; ++i)
+            if (real.getValue().isExactlyValue(i)) {
+              Value base = op.getLhs();
+              Value result = base;
+              for (int j = 1; j < i; ++j)
+                result = rewriter.create<complex::MulOp>(loc, result, base);
+              rewriter.replaceOp(op, result);
+              return success();
+            }
+      }
+    }
+    Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+    Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+    Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+    rewriter.replaceOp(op, exp);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
+  patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +143,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
+  target.addLegalOp<complex::MulOp>();
   target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
-                      complex::LogOp, complex::SinOp, complex::SqrtOp,
-                      complex::TanOp, complex::TanhOp>();
+                      complex::LogOp, complex::PowOp, complex::SinOp,
+                      complex::SqrtOp, complex::TanOp, complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 82936d89e8ac1..ef6ae74a45c1c 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -57,6 +57,33 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
   return %lf, %ld : complex<f32>, complex<f64>
 }
 
+//CHECK-LABEL: @pow_caller
+//CHECK:          (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
+func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
+  // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
+  // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
+  // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
+  // CHECK: return %[[EXP]]
+  %r = complex.pow %z, %w : complex<f32>
+  return %r : complex<f32>
+}
+
+// CHECK-LABEL: @pow_int_caller
+func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
+    ->(complex<f32>, complex<f64>) {
+  // CHECK-NOT: call @__ocml
+  // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+  %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex<f32>
+  %p2 = complex.pow %f, %c2 : complex<f32>
+  // CHECK-NOT: call @__ocml
+  // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
+  // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
+  %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex<f64>
+  %p3 = complex.pow %d, %c3 : complex<f64>
+  // CHECK: return %[[M2]], %[[M3B]]
+  return %p2, %p3 : complex<f32>, complex<f64>
+}
+
 //CHECK-LABEL: @sin_caller
 func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})

>From 720f065a74263f47e5990933771bde7eed193cbe Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 12 Aug 2025 20:07:08 +0100
Subject: [PATCH 2/3] Change constant pow special case handling from
 complex::ConstantOp to complex::CreateOp.

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 45 ++++++++--------
 .../ComplexToROCDLLibraryCalls.cpp            | 54 +++++++++++++------
 .../complex-to-rocdl-library-calls.mlir       | 12 +++--
 3 files changed, 70 insertions(+), 41 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 2f8965adfb320..2269a9b38d746 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,10 +1276,10 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
   return result;
 }
 
-mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc,
-                           const MathOperation &mathOp,
-                           mlir::FunctionType mathLibFuncType,
-                           llvm::ArrayRef<mlir::Value> args) {
+mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
+                          const MathOperation &mathOp,
+                          mlir::FunctionType mathLibFuncType,
+                          llvm::ArrayRef<mlir::Value> args) {
   bool canUseApprox = mlir::arith::bitEnumContainsAny(
       builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
   bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
@@ -1648,18 +1648,18 @@ static constexpr MathOperation mathOperations[] = {
      genMathOp<mlir::math::FPowIOp>},
     {"pow", RTNAME_STRING(cpowi),
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
-     genComplexPowI},
+     genComplexPow},
     {"pow", RTNAME_STRING(zpowi),
      genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
-     genComplexPowI},
+     genComplexPow},
     {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
      genLibF128Call},
     {"pow", RTNAME_STRING(cpowk),
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
-     genComplexPowI},
+     genComplexPow},
     {"pow", RTNAME_STRING(zpowk),
      genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
-     genComplexPowI},
+     genComplexPow},
     {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
      genLibF128Call},
     {"remainder", "remainderf",
@@ -4057,21 +4057,20 @@ void IntrinsicLibrary::genExecuteCommandLine(
     mlir::Value waitAddr = fir::getBase(wait);
     mlir::Value waitIsPresentAtRuntime =
         builder.genIsNotNullAddr(loc, waitAddr);
-    waitBool = builder
-                   .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
-                            /*withElseRegion=*/true)
-                   .genThen([&]() {
-                     auto waitLoad =
-                         fir::LoadOp::create(builder, loc, waitAddr);
-                     mlir::Value cast =
-                         builder.createConvert(loc, i1Ty, waitLoad);
-                     fir::ResultOp::create(builder, loc, cast);
-                   })
-                   .genElse([&]() {
-                     mlir::Value trueVal = builder.createBool(loc, true);
-                     fir::ResultOp::create(builder, loc, trueVal);
-                   })
-                   .getResults()[0];
+    waitBool =
+        builder
+            .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
+                     /*withElseRegion=*/true)
+            .genThen([&]() {
+              auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
+              mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
+              fir::ResultOp::create(builder, loc, cast);
+            })
+            .genElse([&]() {
+              mlir::Value trueVal = builder.createBool(loc, true);
+              fir::ResultOp::create(builder, loc, trueVal);
+            })
+            .getResults()[0];
   }
 
   mlir::Value exitstatBox =
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 558fcdf782800..3bb40dd705cc2 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -58,29 +59,52 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
 };
 
 // Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8
 struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
   using OpRewritePattern<complex::PowOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(complex::PowOp op,
                                 PatternRewriter &rewriter) const final {
     auto loc = op.getLoc();
-    if (auto constOp = op.getRhs().getDefiningOp<complex::ConstantOp>()) {
-      ArrayAttr value = constOp.getValue();
-      if (value.size() == 2) {
-        auto real = dyn_cast<FloatAttr>(value[0]);
-        auto imag = dyn_cast<FloatAttr>(value[1]);
-        if (real && imag && imag.getValue().isZero())
-          for (int i = 2; i <= 8; ++i)
-            if (real.getValue().isExactlyValue(i)) {
-              Value base = op.getLhs();
-              Value result = base;
-              for (int j = 1; j < i; ++j)
-                result = rewriter.create<complex::MulOp>(loc, result, base);
-              rewriter.replaceOp(op, result);
-              return success();
-            }
+
+    auto peelConst = [&](Value val) -> std::optional<TypedAttr> {
+      while (val) {
+        Operation *defOp = val.getDefiningOp();
+        if (!defOp)
+          return std::nullopt;
+
+        if (auto constVal = dyn_cast<arith::ConstantOp>(defOp))
+          return dyn_cast<TypedAttr>(constVal.getValue());
+
+        if (defOp->getName().getStringRef() == "fir.convert" &&
+            defOp->getNumOperands() == 1) {
+          val = defOp->getOperand(0);
+          continue;
+        }
+        return std::nullopt;
+      }
+      return std::nullopt;
+    };
+
+    if (auto createOp = op.getRhs().getDefiningOp<complex::CreateOp>()) {
+      auto image = peelConst(createOp.getImaginary());
+      auto real = peelConst(createOp.getReal());
+      if (image && real) {
+        auto imagFloat = dyn_cast<FloatAttr>(*image);
+        if (imagFloat && imagFloat.getValue().isZero()) {
+          auto realInt = dyn_cast<IntegerAttr>(*real);
+          if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) {
+            Value base = op.getLhs();
+            Value result = base;
+            for (int i = 1; i < realInt.getInt(); ++i)
+              result = rewriter.create<complex::MulOp>(loc, result, base);
+            rewriter.replaceOp(op, result);
+            return success();
+          }
+        }
       }
     }
+
     Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
     Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
     Value exp = rewriter.create<complex::ExpOp>(loc, mul);
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index ef6ae74a45c1c..ba0dd92e20747 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s
 
 // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
@@ -73,12 +73,18 @@ func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
     ->(complex<f32>, complex<f64>) {
   // CHECK-NOT: call @__ocml
   // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
-  %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex<f32>
+  %c2_i32 = arith.constant 2 : i32
+  %c2r = "fir.convert"(%c2_i32) : (i32) -> f32
+  %c2i = arith.constant 0.0 : f32
+  %c2 = complex.create %c2r, %c2i : complex<f32>
   %p2 = complex.pow %f, %c2 : complex<f32>
   // CHECK-NOT: call @__ocml
   // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
   // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
-  %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex<f64>
+  %c3_i32 = arith.constant 3 : i32
+  %c3r = "fir.convert"(%c3_i32) : (i32) -> f64
+  %c3i = arith.constant 0.0 : f64
+  %c3 = complex.create %c3r, %c3i : complex<f64>
   %p3 = complex.pow %d, %c3 : complex<f64>
   // CHECK: return %[[M2]], %[[M3B]]
   return %p2, %p3 : complex<f32>, complex<f64>

>From 2a4017c0cda802bb14b134c2395b6c774e94ea18 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 14 Aug 2025 13:12:30 +0100
Subject: [PATCH 3/3] Move cpow constant optimisation to Fortran lowering.

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp |  8 ++++
 flang/test/Lower/amdgcn-complex.f90           | 17 ++++++--
 flang/test/Lower/power-operator.f90           |  8 ++++
 .../ComplexToROCDLLibraryCalls.cpp            | 41 -------------------
 .../complex-to-rocdl-library-calls.mlir       | 22 ----------
 5 files changed, 29 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 2269a9b38d746..d7d63fd9f8b3b 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1280,6 +1280,14 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
                           const MathOperation &mathOp,
                           mlir::FunctionType mathLibFuncType,
                           llvm::ArrayRef<mlir::Value> args) {
+  if (auto expInt = fir::getIntIfConstant(args[1]))
+    if (*expInt >= 2 && *expInt <= 8) {
+      mlir::Value result = args[0];
+      for (int i = 1; i < *expInt; ++i)
+        result = builder.create<mlir::complex::MulOp>(loc, result, args[0]);
+      return builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+    }
+
   bool canUseApprox = mlir::arith::bitEnumContainsAny(
       builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
   bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 3d52355d3d50a..dab8cb4034883 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -18,10 +18,19 @@ subroutine cexpf_test(a, b)
    b = exp(a)
 end subroutine
 
-! CHECK-LABEL: func @_QPpow_test(
-! CHECK: complex.pow
+! CHECK-LABEL: func @_QPpow_test1(
+! CHECK: complex.mul
+! CHECK-NOT: complex.pow
 ! CHECK-NOT: fir.call @_FortranAcpowi
-subroutine pow_test(a, b)
+subroutine pow_test1(a, b)
    complex :: a, b
    a = b**2
-end subroutine pow_test
+end subroutine pow_test1
+
+! CHECK-LABEL: func @_QPpow_test2(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test2(a, b, c)
+   complex :: a, b, c
+   a = b**c
+end subroutine pow_test2
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 2a0a09e090dde..a8943a3aa8c0b 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -143,3 +143,11 @@ subroutine pow_c8_c8(x, y, z)
   ! PRECISE: call @cpow
 end subroutine
 
+! CHECK-LABEL: pow_const
+subroutine pow_const(a, b)
+  complex :: a, b
+  ! CHECK-NOT: complex.pow
+  ! CHECK-NOT: @_FortranAcpowi
+  ! CHECK-COUNT-3: complex.mul
+  a = b**4
+end subroutine
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 3bb40dd705cc2..cc0e93248a114 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -59,52 +58,12 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
 };
 
 // Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
-// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8
 struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
   using OpRewritePattern<complex::PowOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(complex::PowOp op,
                                 PatternRewriter &rewriter) const final {
     auto loc = op.getLoc();
-
-    auto peelConst = [&](Value val) -> std::optional<TypedAttr> {
-      while (val) {
-        Operation *defOp = val.getDefiningOp();
-        if (!defOp)
-          return std::nullopt;
-
-        if (auto constVal = dyn_cast<arith::ConstantOp>(defOp))
-          return dyn_cast<TypedAttr>(constVal.getValue());
-
-        if (defOp->getName().getStringRef() == "fir.convert" &&
-            defOp->getNumOperands() == 1) {
-          val = defOp->getOperand(0);
-          continue;
-        }
-        return std::nullopt;
-      }
-      return std::nullopt;
-    };
-
-    if (auto createOp = op.getRhs().getDefiningOp<complex::CreateOp>()) {
-      auto image = peelConst(createOp.getImaginary());
-      auto real = peelConst(createOp.getReal());
-      if (image && real) {
-        auto imagFloat = dyn_cast<FloatAttr>(*image);
-        if (imagFloat && imagFloat.getValue().isZero()) {
-          auto realInt = dyn_cast<IntegerAttr>(*real);
-          if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) {
-            Value base = op.getLhs();
-            Value result = base;
-            for (int i = 1; i < realInt.getInt(); ++i)
-              result = rewriter.create<complex::MulOp>(loc, result, base);
-            rewriter.replaceOp(op, result);
-            return success();
-          }
-        }
-      }
-    }
-
     Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
     Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
     Value exp = rewriter.create<complex::ExpOp>(loc, mul);
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index ba0dd92e20747..080ba4f0ff67b 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -68,28 +68,6 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
   return %r : complex<f32>
 }
 
-// CHECK-LABEL: @pow_int_caller
-func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
-    ->(complex<f32>, complex<f64>) {
-  // CHECK-NOT: call @__ocml
-  // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
-  %c2_i32 = arith.constant 2 : i32
-  %c2r = "fir.convert"(%c2_i32) : (i32) -> f32
-  %c2i = arith.constant 0.0 : f32
-  %c2 = complex.create %c2r, %c2i : complex<f32>
-  %p2 = complex.pow %f, %c2 : complex<f32>
-  // CHECK-NOT: call @__ocml
-  // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
-  // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
-  %c3_i32 = arith.constant 3 : i32
-  %c3r = "fir.convert"(%c3_i32) : (i32) -> f64
-  %c3i = arith.constant 0.0 : f64
-  %c3 = complex.create %c3r, %c3i : complex<f64>
-  %p3 = complex.pow %d, %c3 : complex<f64>
-  // CHECK: return %[[M2]], %[[M3B]]
-  return %p2, %p3 : complex<f32>, complex<f64>
-}
-
 //CHECK-LABEL: @sin_caller
 func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})



More information about the flang-commits mailing list