[flang-commits] [flang] d69ccde - [MLIR] Add cpow support in ComplexToROCDLLibraryCalls (#153183)

via flang-commits flang-commits at lists.llvm.org
Wed Aug 20 10:18:34 PDT 2025


Author: Akash Banerjee
Date: 2025-08-20T17:18:30Z
New Revision: d69ccded4ff14644245990e4ecc4f96e9610dd3d

URL: https://github.com/llvm/llvm-project/commit/d69ccded4ff14644245990e4ecc4f96e9610dd3d
DIFF: https://github.com/llvm/llvm-project/commit/d69ccded4ff14644245990e4ecc4f96e9610dd3d.diff

LOG: [MLIR] Add cpow support in ComplexToROCDLLibraryCalls (#153183)

This PR adds support for complex power operations (`cpow`) in the
`ComplexToROCDLLibraryCalls` conversion pass, specifically targeting
AMDGPU architectures. The implementation optimises complex
exponentiation by using mathematical identities and special-case
handling for small integer powers.

- In Flang, lower complex powers with integer exponents to `complex.pow`
for the `amdgcn-amd-amdhsa` target instead of emitting runtime library calls
- In the conversion pass, rewrite `complex.pow(z, w)` as
`complex.exp(w * complex.log(z))`, using the identity z^w = exp(w * log(z))

Added: 
    

Modified: 
    flang/lib/Optimizer/Builder/IntrinsicCall.cpp
    flang/test/Lower/amdgcn-complex.f90
    mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
    mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8aacdb1815e3b..cc3ae31ff4668 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1287,6 +1287,26 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
   return result;
 }
 
+mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
+                          const MathOperation &mathOp,
+                          mlir::FunctionType mathLibFuncType,
+                          llvm::ArrayRef<mlir::Value> args) {
+  bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
+  if (!isAMDGPU)
+    return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
+
+  auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
+  auto realTy = complexTy.getElementType();
+  mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
+  mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+  mlir::Value complexExp =
+      builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+  mlir::Value result =
+      builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+  result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+  return result;
+}
+
 /// Mapping between mathematical intrinsic operations and MLIR operations
 /// of some appropriate dialect (math, complex, etc.) or libm calls.
 /// TODO: support remaining Fortran math intrinsics.
@@ -1636,15 +1656,19 @@ static constexpr MathOperation mathOperations[] = {
      genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
      genMathOp<mlir::math::FPowIOp>},
     {"pow", RTNAME_STRING(cpowi),
-     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
+     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
+     genComplexPow},
     {"pow", RTNAME_STRING(zpowi),
-     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
+     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
+     genComplexPow},
     {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
      genLibF128Call},
     {"pow", RTNAME_STRING(cpowk),
-     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
+     genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
+     genComplexPow},
     {"pow", RTNAME_STRING(zpowk),
-     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
+     genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
+     genComplexPow},
     {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
      genLibF128Call},
     {"remainder", "remainderf",
@@ -4044,21 +4068,20 @@ void IntrinsicLibrary::genExecuteCommandLine(
     mlir::Value waitAddr = fir::getBase(wait);
     mlir::Value waitIsPresentAtRuntime =
         builder.genIsNotNullAddr(loc, waitAddr);
-    waitBool = builder
-                   .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
-                            /*withElseRegion=*/true)
-                   .genThen([&]() {
-                     auto waitLoad =
-                         fir::LoadOp::create(builder, loc, waitAddr);
-                     mlir::Value cast =
-                         builder.createConvert(loc, i1Ty, waitLoad);
-                     fir::ResultOp::create(builder, loc, cast);
-                   })
-                   .genElse([&]() {
-                     mlir::Value trueVal = builder.createBool(loc, true);
-                     fir::ResultOp::create(builder, loc, trueVal);
-                   })
-                   .getResults()[0];
+    waitBool =
+        builder
+            .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
+                     /*withElseRegion=*/true)
+            .genThen([&]() {
+              auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
+              mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
+              fir::ResultOp::create(builder, loc, cast);
+            })
+            .genElse([&]() {
+              mlir::Value trueVal = builder.createBool(loc, true);
+              fir::ResultOp::create(builder, loc, trueVal);
+            })
+            .getResults()[0];
   }
 
   mlir::Value exitstatBox =

diff  --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index f15c7db2b7316..4ee5de4d2842e 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -1,21 +1,27 @@
 ! REQUIRES: amdgpu-registered-target
-! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
 
+! CHECK-LABEL: func @_QPcabsf_test(
+! CHECK: complex.abs
+! CHECK-NOT: fir.call @cabsf
 subroutine cabsf_test(a, b)
    complex :: a
    real :: b
    b = abs(a)
 end subroutine
 
-! CHECK-LABEL: func @_QPcabsf_test(
-! CHECK: complex.abs
-! CHECK-NOT: fir.call @cabsf
-
+! CHECK-LABEL: func @_QPcexpf_test(
+! CHECK: complex.exp
+! CHECK-NOT: fir.call @cexpf
 subroutine cexpf_test(a, b)
    complex :: a, b
    b = exp(a)
 end subroutine
 
-! CHECK-LABEL: func @_QPcexpf_test(
-! CHECK: complex.exp
-! CHECK-NOT: fir.call @cexpf
+! CHECK-LABEL: func @_QPpow_test(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test(a, b, c)
+   complex :: a, b, c
+   a = b**c
+end subroutine pow_test

diff  --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index b3d6d59e25bd0..7a3a7fdb73e5e 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,10 +56,26 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
 private:
   std::string funcName;
 };
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+  using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::PowOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+    Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+    Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+    rewriter.replaceOp(op, exp);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
+  patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
+  target.addLegalOp<complex::MulOp>();
   target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
-                      complex::LogOp, complex::SinOp, complex::SqrtOp,
-                      complex::TanOp, complex::TanhOp>();
+                      complex::LogOp, complex::PowOp, complex::SinOp,
+                      complex::SqrtOp, complex::TanOp, complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 82936d89e8ac1..080ba4f0ff67b 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s
 
 // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
@@ -57,6 +57,17 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
   return %lf, %ld : complex<f32>, complex<f64>
 }
 
+//CHECK-LABEL: @pow_caller
+//CHECK:          (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
+func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
+  // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
+  // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
+  // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
+  // CHECK: return %[[EXP]]
+  %r = complex.pow %z, %w : complex<f32>
+  return %r : complex<f32>
+}
+
 //CHECK-LABEL: @sin_caller
 func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})


        


More information about the flang-commits mailing list