[flang-commits] [flang] [mlir] [MLIR] Add cpow support in ComplexToROCDLLibraryCalls (PR #153183)
Akash Banerjee via flang-commits
flang-commits at lists.llvm.org
Tue Aug 19 08:14:25 PDT 2025
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/153183
>From 57181a27698fba69200ec20af6c5743acdc57f3a Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 12 Aug 2025 14:06:38 +0100
Subject: [PATCH 1/4] [MLIR] Add cpow support in ComplexToROCDLLibraryCalls
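This PR contributes the following changes:
1. Lower complex ** integer exponentiation to complex.pow ops instead of Fortran runtime calls (cpowi/zpowi/cpowk/zpowk) for the amdgcn-amd-amdhsa target (and, as for other complex ops, when approximate fast-math `afn` is enabled).
2. Convert complex.pow(z, w) -> complex.exp(w * complex.log(z)).
3. Expand constant integer exponents from 2 through 8 into repeated multiplication, e.g. x ** 2 -> x * x and x ** 3 -> x * x * x.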
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 34 +++++++++++++++--
flang/test/Lower/amdgcn-complex.f90 | 22 +++++++----
flang/test/Lower/power-operator.f90 | 12 ++++--
.../ComplexToROCDLLibraryCalls.cpp | 38 ++++++++++++++++++-
.../complex-to-rocdl-library-calls.mlir | 27 +++++++++++++
5 files changed, 115 insertions(+), 18 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 22193f0de88a1..74279a7d72078 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,6 +1276,28 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
+mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc,
+ const MathOperation &mathOp,
+ mlir::FunctionType mathLibFuncType,
+ llvm::ArrayRef<mlir::Value> args) {
+ bool canUseApprox = mlir::arith::bitEnumContainsAny(
+ builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
+ bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
+ if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
+ return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
+
+ auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ mlir::Value complexExp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result =
+ builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+ return result;
+}
+
/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(zpowi),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
{"pow", RTNAME_STRING(cpowk),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(zpowk),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
{"remainder", "remainderf",
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index f15c7db2b7316..3d52355d3d50a 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -1,21 +1,27 @@
! REQUIRES: amdgpu-registered-target
-! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
+! CHECK-LABEL: func @_QPcabsf_test(
+! CHECK: complex.abs
+! CHECK-NOT: fir.call @cabsf
subroutine cabsf_test(a, b)
complex :: a
real :: b
b = abs(a)
end subroutine
-! CHECK-LABEL: func @_QPcabsf_test(
-! CHECK: complex.abs
-! CHECK-NOT: fir.call @cabsf
-
+! CHECK-LABEL: func @_QPcexpf_test(
+! CHECK: complex.exp
+! CHECK-NOT: fir.call @cexpf
subroutine cexpf_test(a, b)
complex :: a, b
b = exp(a)
end subroutine
-! CHECK-LABEL: func @_QPcexpf_test(
-! CHECK: complex.exp
-! CHECK-NOT: fir.call @cexpf
+! CHECK-LABEL: func @_QPpow_test(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test(a, b)
+ complex :: a, b
+ a = b**2
+end subroutine pow_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..2a0a09e090dde 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAcpowi
+ ! PRECISE: call @_FortranAcpowi
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine
! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAcpowk
+ ! PRECISE: call @_FortranAcpowk
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine
! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAzpowi
+ ! PRECISE: call @_FortranAzpowi
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine
! CHECK-LABEL: pow_c8_i8
@@ -120,7 +123,8 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAzpowk
+ ! PRECISE: call @_FortranAzpowk
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine
! CHECK-LABEL: pow_c4_c4
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index b3d6d59e25bd0..558fcdf782800 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,10 +56,43 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
private:
std::string funcName;
};
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+ using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowOp op,
+ PatternRewriter &rewriter) const final {
+ auto loc = op.getLoc();
+ if (auto constOp = op.getRhs().getDefiningOp<complex::ConstantOp>()) {
+ ArrayAttr value = constOp.getValue();
+ if (value.size() == 2) {
+ auto real = dyn_cast<FloatAttr>(value[0]);
+ auto imag = dyn_cast<FloatAttr>(value[1]);
+ if (real && imag && imag.getValue().isZero())
+ for (int i = 2; i <= 8; ++i)
+ if (real.getValue().isExactlyValue(i)) {
+ Value base = op.getLhs();
+ Value result = base;
+ for (int j = 1; j < i; ++j)
+ result = rewriter.create<complex::MulOp>(loc, result, base);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+ }
+ }
+ Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+ Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+ Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+ rewriter.replaceOp(op, exp);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +143,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
+ target.addLegalOp<complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
- complex::LogOp, complex::SinOp, complex::SqrtOp,
- complex::TanOp, complex::TanhOp>();
+ complex::LogOp, complex::PowOp, complex::SinOp,
+ complex::SqrtOp, complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 82936d89e8ac1..ef6ae74a45c1c 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -57,6 +57,33 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
return %lf, %ld : complex<f32>, complex<f64>
}
+//CHECK-LABEL: @pow_caller
+//CHECK: (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
+func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
+ // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
+ // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
+ // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
+ // CHECK: return %[[EXP]]
+ %r = complex.pow %z, %w : complex<f32>
+ return %r : complex<f32>
+}
+
+// CHECK-LABEL: @pow_int_caller
+func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
+ ->(complex<f32>, complex<f64>) {
+ // CHECK-NOT: call @__ocml
+ // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+ %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex<f32>
+ %p2 = complex.pow %f, %c2 : complex<f32>
+ // CHECK-NOT: call @__ocml
+ // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
+ // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
+ %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex<f64>
+ %p3 = complex.pow %d, %c3 : complex<f64>
+ // CHECK: return %[[M2]], %[[M3B]]
+ return %p2, %p3 : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
>From 9f0145439c29f91a0b71f39675cd26603a585e20 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 12 Aug 2025 20:07:08 +0100
Subject: [PATCH 2/4] Change constant pow special case handling from
complex::ConstantOp to complex::CreateOp.
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 45 ++++++++--------
.../ComplexToROCDLLibraryCalls.cpp | 54 +++++++++++++------
.../complex-to-rocdl-library-calls.mlir | 12 +++--
3 files changed, 70 insertions(+), 41 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 74279a7d72078..89866bb143fba 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,10 +1276,10 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
-mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc,
- const MathOperation &mathOp,
- mlir::FunctionType mathLibFuncType,
- llvm::ArrayRef<mlir::Value> args) {
+mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
+ const MathOperation &mathOp,
+ mlir::FunctionType mathLibFuncType,
+ llvm::ArrayRef<mlir::Value> args) {
bool canUseApprox = mlir::arith::bitEnumContainsAny(
builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
@@ -1648,18 +1648,18 @@ static constexpr MathOperation mathOperations[] = {
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
- genComplexPowI},
+ genComplexPow},
{"pow", RTNAME_STRING(zpowi),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
- genComplexPowI},
+ genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
- genComplexPowI},
+ genComplexPow},
{"pow", RTNAME_STRING(zpowk),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
- genComplexPowI},
+ genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
{"remainder", "remainderf",
@@ -4058,21 +4058,20 @@ void IntrinsicLibrary::genExecuteCommandLine(
mlir::Value waitAddr = fir::getBase(wait);
mlir::Value waitIsPresentAtRuntime =
builder.genIsNotNullAddr(loc, waitAddr);
- waitBool = builder
- .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
- /*withElseRegion=*/true)
- .genThen([&]() {
- auto waitLoad =
- fir::LoadOp::create(builder, loc, waitAddr);
- mlir::Value cast =
- builder.createConvert(loc, i1Ty, waitLoad);
- fir::ResultOp::create(builder, loc, cast);
- })
- .genElse([&]() {
- mlir::Value trueVal = builder.createBool(loc, true);
- fir::ResultOp::create(builder, loc, trueVal);
- })
- .getResults()[0];
+ waitBool =
+ builder
+ .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
+ mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
+ fir::ResultOp::create(builder, loc, cast);
+ })
+ .genElse([&]() {
+ mlir::Value trueVal = builder.createBool(loc, true);
+ fir::ResultOp::create(builder, loc, trueVal);
+ })
+ .getResults()[0];
}
mlir::Value exitstatBox =
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 558fcdf782800..3bb40dd705cc2 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -58,29 +59,52 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
};
// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8
struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
using OpRewritePattern<complex::PowOp>::OpRewritePattern;
LogicalResult matchAndRewrite(complex::PowOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
- if (auto constOp = op.getRhs().getDefiningOp<complex::ConstantOp>()) {
- ArrayAttr value = constOp.getValue();
- if (value.size() == 2) {
- auto real = dyn_cast<FloatAttr>(value[0]);
- auto imag = dyn_cast<FloatAttr>(value[1]);
- if (real && imag && imag.getValue().isZero())
- for (int i = 2; i <= 8; ++i)
- if (real.getValue().isExactlyValue(i)) {
- Value base = op.getLhs();
- Value result = base;
- for (int j = 1; j < i; ++j)
- result = rewriter.create<complex::MulOp>(loc, result, base);
- rewriter.replaceOp(op, result);
- return success();
- }
+
+ auto peelConst = [&](Value val) -> std::optional<TypedAttr> {
+ while (val) {
+ Operation *defOp = val.getDefiningOp();
+ if (!defOp)
+ return std::nullopt;
+
+ if (auto constVal = dyn_cast<arith::ConstantOp>(defOp))
+ return dyn_cast<TypedAttr>(constVal.getValue());
+
+ if (defOp->getName().getStringRef() == "fir.convert" &&
+ defOp->getNumOperands() == 1) {
+ val = defOp->getOperand(0);
+ continue;
+ }
+ return std::nullopt;
+ }
+ return std::nullopt;
+ };
+
+ if (auto createOp = op.getRhs().getDefiningOp<complex::CreateOp>()) {
+ auto image = peelConst(createOp.getImaginary());
+ auto real = peelConst(createOp.getReal());
+ if (image && real) {
+ auto imagFloat = dyn_cast<FloatAttr>(*image);
+ if (imagFloat && imagFloat.getValue().isZero()) {
+ auto realInt = dyn_cast<IntegerAttr>(*real);
+ if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) {
+ Value base = op.getLhs();
+ Value result = base;
+ for (int i = 1; i < realInt.getInt(); ++i)
+ result = rewriter.create<complex::MulOp>(loc, result, base);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+ }
}
}
+
Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
Value exp = rewriter.create<complex::ExpOp>(loc, mul);
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index ef6ae74a45c1c..ba0dd92e20747 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
@@ -73,12 +73,18 @@ func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
->(complex<f32>, complex<f64>) {
// CHECK-NOT: call @__ocml
// CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
- %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex<f32>
+ %c2_i32 = arith.constant 2 : i32
+ %c2r = "fir.convert"(%c2_i32) : (i32) -> f32
+ %c2i = arith.constant 0.0 : f32
+ %c2 = complex.create %c2r, %c2i : complex<f32>
%p2 = complex.pow %f, %c2 : complex<f32>
// CHECK-NOT: call @__ocml
// CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
// CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
- %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex<f64>
+ %c3_i32 = arith.constant 3 : i32
+ %c3r = "fir.convert"(%c3_i32) : (i32) -> f64
+ %c3i = arith.constant 0.0 : f64
+ %c3 = complex.create %c3r, %c3i : complex<f64>
%p3 = complex.pow %d, %c3 : complex<f64>
// CHECK: return %[[M2]], %[[M3B]]
return %p2, %p3 : complex<f32>, complex<f64>
>From dce85c3dbb3d503c945b6088260bb572d05f1933 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 14 Aug 2025 13:12:30 +0100
Subject: [PATCH 3/4] Move cpow constant optimisation to Fortran lowering.
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 8 ++++
flang/test/Lower/amdgcn-complex.f90 | 17 ++++++--
flang/test/Lower/power-operator.f90 | 8 ++++
.../ComplexToROCDLLibraryCalls.cpp | 41 -------------------
.../complex-to-rocdl-library-calls.mlir | 22 ----------
5 files changed, 29 insertions(+), 67 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 89866bb143fba..a424007eee799 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1280,6 +1280,14 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
+ if (auto expInt = fir::getIntIfConstant(args[1]))
+ if (*expInt >= 2 && *expInt <= 8) {
+ mlir::Value result = args[0];
+ for (int i = 1; i < *expInt; ++i)
+ result = builder.create<mlir::complex::MulOp>(loc, result, args[0]);
+ return builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+ }
+
bool canUseApprox = mlir::arith::bitEnumContainsAny(
builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 3d52355d3d50a..dab8cb4034883 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -18,10 +18,19 @@ subroutine cexpf_test(a, b)
b = exp(a)
end subroutine
-! CHECK-LABEL: func @_QPpow_test(
-! CHECK: complex.pow
+! CHECK-LABEL: func @_QPpow_test1(
+! CHECK: complex.mul
+! CHECK-NOT: complex.pow
! CHECK-NOT: fir.call @_FortranAcpowi
-subroutine pow_test(a, b)
+subroutine pow_test1(a, b)
complex :: a, b
a = b**2
-end subroutine pow_test
+end subroutine pow_test1
+
+! CHECK-LABEL: func @_QPpow_test2(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test2(a, b, c)
+ complex :: a, b, c
+ a = b**c
+end subroutine pow_test2
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 2a0a09e090dde..a8943a3aa8c0b 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -143,3 +143,11 @@ subroutine pow_c8_c8(x, y, z)
! PRECISE: call @cpow
end subroutine
+! CHECK-LABEL: pow_const
+subroutine pow_const(a, b)
+ complex :: a, b
+ ! CHECK-NOT: complex.pow
+ ! CHECK-NOT: @_FortranAcpowi
+ ! CHECK-COUNT-3: complex.mul
+ a = b**4
+end subroutine
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 3bb40dd705cc2..cc0e93248a114 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -59,52 +58,12 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
};
// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
-// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8
struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
using OpRewritePattern<complex::PowOp>::OpRewritePattern;
LogicalResult matchAndRewrite(complex::PowOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
-
- auto peelConst = [&](Value val) -> std::optional<TypedAttr> {
- while (val) {
- Operation *defOp = val.getDefiningOp();
- if (!defOp)
- return std::nullopt;
-
- if (auto constVal = dyn_cast<arith::ConstantOp>(defOp))
- return dyn_cast<TypedAttr>(constVal.getValue());
-
- if (defOp->getName().getStringRef() == "fir.convert" &&
- defOp->getNumOperands() == 1) {
- val = defOp->getOperand(0);
- continue;
- }
- return std::nullopt;
- }
- return std::nullopt;
- };
-
- if (auto createOp = op.getRhs().getDefiningOp<complex::CreateOp>()) {
- auto image = peelConst(createOp.getImaginary());
- auto real = peelConst(createOp.getReal());
- if (image && real) {
- auto imagFloat = dyn_cast<FloatAttr>(*image);
- if (imagFloat && imagFloat.getValue().isZero()) {
- auto realInt = dyn_cast<IntegerAttr>(*real);
- if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) {
- Value base = op.getLhs();
- Value result = base;
- for (int i = 1; i < realInt.getInt(); ++i)
- result = rewriter.create<complex::MulOp>(loc, result, base);
- rewriter.replaceOp(op, result);
- return success();
- }
- }
- }
- }
-
Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
Value exp = rewriter.create<complex::ExpOp>(loc, mul);
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index ba0dd92e20747..080ba4f0ff67b 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -68,28 +68,6 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
return %r : complex<f32>
}
-// CHECK-LABEL: @pow_int_caller
-func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
- ->(complex<f32>, complex<f64>) {
- // CHECK-NOT: call @__ocml
- // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
- %c2_i32 = arith.constant 2 : i32
- %c2r = "fir.convert"(%c2_i32) : (i32) -> f32
- %c2i = arith.constant 0.0 : f32
- %c2 = complex.create %c2r, %c2i : complex<f32>
- %p2 = complex.pow %f, %c2 : complex<f32>
- // CHECK-NOT: call @__ocml
- // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
- // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
- %c3_i32 = arith.constant 3 : i32
- %c3r = "fir.convert"(%c3_i32) : (i32) -> f64
- %c3i = arith.constant 0.0 : f64
- %c3 = complex.create %c3r, %c3i : complex<f64>
- %p3 = complex.pow %d, %c3 : complex<f64>
- // CHECK: return %[[M2]], %[[M3B]]
- return %p2, %p3 : complex<f32>, complex<f64>
-}
-
//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
>From 0cebae8de1fb432150811f319cbcd12522c2321e Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 19 Aug 2025 16:07:57 +0100
Subject: [PATCH 4/4] Remove constant optimisation.
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 8 --------
flang/test/Lower/amdgcn-complex.f90 | 15 +++------------
flang/test/Lower/power-operator.f90 | 9 ---------
3 files changed, 3 insertions(+), 29 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a424007eee799..89866bb143fba 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1280,14 +1280,6 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
- if (auto expInt = fir::getIntIfConstant(args[1]))
- if (*expInt >= 2 && *expInt <= 8) {
- mlir::Value result = args[0];
- for (int i = 1; i < *expInt; ++i)
- result = builder.create<mlir::complex::MulOp>(loc, result, args[0]);
- return builder.createConvert(loc, mathLibFuncType.getResult(0), result);
- }
-
bool canUseApprox = mlir::arith::bitEnumContainsAny(
builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index dab8cb4034883..4ee5de4d2842e 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -18,19 +18,10 @@ subroutine cexpf_test(a, b)
b = exp(a)
end subroutine
-! CHECK-LABEL: func @_QPpow_test1(
-! CHECK: complex.mul
-! CHECK-NOT: complex.pow
-! CHECK-NOT: fir.call @_FortranAcpowi
-subroutine pow_test1(a, b)
- complex :: a, b
- a = b**2
-end subroutine pow_test1
-
-! CHECK-LABEL: func @_QPpow_test2(
+! CHECK-LABEL: func @_QPpow_test(
! CHECK: complex.pow
! CHECK-NOT: fir.call @_FortranAcpowi
-subroutine pow_test2(a, b, c)
+subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
-end subroutine pow_test2
+end subroutine pow_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index a8943a3aa8c0b..ebce4f52d449d 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -142,12 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: call @cpow
end subroutine
-
-! CHECK-LABEL: pow_const
-subroutine pow_const(a, b)
- complex :: a, b
- ! CHECK-NOT: complex.pow
- ! CHECK-NOT: @_FortranAcpowi
- ! CHECK-COUNT-3: complex.mul
- a = b**4
-end subroutine
More information about the flang-commits
mailing list