[llvm-branch-commits] [flang] [mlir] [MLIR] Add new complex.powi op (PR #158722)
Akash Banerjee via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Sep 18 08:20:22 PDT 2025
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158722
From 6976910364aa2fe18603aefcb27b10bd0120513d Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Mon, 15 Sep 2025 20:35:29 +0100
Subject: [PATCH 1/6] Add complex.powi op.
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 20 ++--
.../Transforms/ConvertComplexPow.cpp | 94 +++++++++----------
flang/test/Lower/HLFIR/binary-ops.f90 | 2 +-
.../test/Lower/Intrinsics/pow_complex16i.f90 | 2 +-
.../test/Lower/Intrinsics/pow_complex16k.f90 | 2 +-
flang/test/Lower/amdgcn-complex.f90 | 9 ++
flang/test/Lower/power-operator.f90 | 9 +-
.../mlir/Dialect/Complex/IR/ComplexOps.td | 26 +++++
.../ComplexToROCDLLibraryCalls.cpp | 41 +++++++-
.../Transforms/AlgebraicSimplification.cpp | 24 +++--
.../Dialect/Math/Transforms/CMakeLists.txt | 1 +
.../complex-to-rocdl-library-calls.mlir | 14 +++
mlir/test/Dialect/Complex/powi-simplify.mlir | 20 ++++
13 files changed, 188 insertions(+), 76 deletions(-)
create mode 100644 mlir/test/Dialect/Complex/powi-simplify.mlir
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 466458c05dba7..74a4e8f85c8ff 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
mlir::Value exp = args[1];
- if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, exp);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- exp =
- builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result;
+ if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
+ mlir::isa<mlir::IndexType>(exp.getType())) {
+ result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
+ } else {
+ if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
+ zero);
+ }
+ result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
}
- mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index 78f9d9e4f639a..d76451459def9 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -58,63 +58,57 @@ void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
- mod.walk([&](complex::PowOp op) {
+ mod.walk([&](complex::PowiOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();
-
Value base = op.getLhs();
- Value rhs = op.getRhs();
-
- Value intExp;
- if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
- if (isZero(create.getImaginary())) {
- if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
- if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
- intExp = conv.getValue();
- }
- }
- }
-
+ Value intExp = op.getRhs();
func::FuncOp callee;
- SmallVector<Value> args;
- if (intExp) {
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
- auto funcTy = builder.getFunctionType(
- {complexTy, builder.getIntegerType(intBits)}, {complexTy});
- if (realBits == 32 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
- else if (realBits == 32 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
- else if (realBits == 64 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
- else if (realBits == 64 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
- else if (realBits == 128 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
- else if (realBits == 128 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
- else
- return;
- args = {base, intExp};
- } else {
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- auto funcTy =
- builder.getFunctionType({complexTy, complexTy}, {complexTy});
- if (realBits == 32)
- callee = getOrDeclare(builder, loc, "cpowf", funcTy);
- else if (realBits == 64)
- callee = getOrDeclare(builder, loc, "cpow", funcTy);
- else if (realBits == 128)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
- else
- return;
- args = {base, rhs};
- }
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+ auto funcTy = builder.getFunctionType(
+ {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+ if (realBits == 32 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+ else if (realBits == 32 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+ else if (realBits == 64 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+ else if (realBits == 64 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+ else if (realBits == 128 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+ else if (realBits == 128 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+ else
+ return;
+ auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+ if (auto fmf = op.getFastmathAttr())
+ call.setFastmathAttr(fmf);
+ op.replaceAllUsesWith(call.getResult(0));
+ op.erase();
+ });
- auto call = fir::CallOp::create(builder, loc, callee, args);
+ mod.walk([&](complex::PowOp op) {
+ builder.setInsertionPoint(op);
+ Location loc = op.getLoc();
+ auto complexTy = cast<ComplexType>(op.getType());
+ auto elemTy = complexTy.getElementType();
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ func::FuncOp callee;
+ auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
+ if (realBits == 32)
+ callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+ else if (realBits == 64)
+ callee = getOrDeclare(builder, loc, "cpow", funcTy);
+ else if (realBits == 128)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+ else
+ return;
+ auto call =
+ fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
if (auto fmf = op.getFastmathAttr())
call.setFastmathAttr(fmf);
op.replaceAllUsesWith(call.getResult(0));
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 1fbd333db37c3..7e1691dd1587a 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = complex.pow
+! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 1827863a57f43..0b26024b02021 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 039dfd5152a06..90a9f5e03628d 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 4ee5de4d2842e..a28eaea82379b 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
end subroutine pow_test
+
+! CHECK-LABEL: func @_QPpowi_test(
+! CHECK: complex.powi
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine powi_test(a, b, c)
+ complex :: a, b
+ integer :: i
+ b = a ** i
+end subroutine powi_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 3058927144248..9f74d172a6bb2 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
! PRECISE: fir.call @_FortranAcpowi
end subroutine
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
! PRECISE: fir.call @_FortranAcpowk
end subroutine
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
! PRECISE: fir.call @_FortranAzpowi
end subroutine
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
! PRECISE: fir.call @_FortranAzpowk
end subroutine
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: fir.call @cpow
end subroutine
-
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..ca5103c16889c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
}];
}
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+ [Pure, Elementwise, SameOperandsAndResultShape,
+ AllTypesMatch<["lhs", "result"]>]> {
+ let summary = "complex number raised to integer power";
+ let description = [{
+ The `powi` operation takes a complex number and an integer exponent.
+
+ Example:
+
+ ```mlir
+ %a = complex.powi %b, %c : complex<f32>, i32
+ ```
+ }];
+
+ let arguments = (ins Complex<AnyFloat>:$lhs,
+ AnySignlessInteger:$rhs);
+ let results = (outs Complex<AnyFloat>:$result);
+
+ let assemblyFormat =
+ "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+}
+
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 72b1fa6e833f9..361e422ce1468 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -74,10 +76,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
return success();
}
};
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+ using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowiOp op,
+ PatternRewriter &rewriter) const final {
+ auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+ Type elementType = complexType.getElementType();
+
+ Type exponentType = op.getRhs().getType();
+ Type exponentFloatType = elementType;
+ if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+ exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+ Location loc = op.getLoc();
+ Value exponentReal =
+ rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+ Value zeroImag = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(exponentFloatType));
+ Value exponent = rewriter.create<complex::CreateOp>(
+ loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+ rewriter
+ .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+ exponent);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
@@ -128,11 +160,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
ConversionTarget target(getContext());
- target.addLegalDialect<func::FuncDialect>();
- target.addLegalOp<complex::MulOp>();
+ target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+ target.addLegalOp<complex::CreateOp, complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
- complex::LogOp, complex::PowOp, complex::SinOp,
- complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+ complex::LogOp, complex::PowOp, complex::PowiOp,
+ complex::SinOp, complex::SqrtOp, complex::TanOp,
+ complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..3711c112cc631 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
Value one;
Type opType = getElementTypeOrSelf(op.getType());
- if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+ if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
one = arith::ConstantOp::create(rewriter, loc,
rewriter.getFloatAttr(opType, 1.0));
- else
+ } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+ auto complexTy = cast<ComplexType>(opType);
+ Type elementType = complexTy.getElementType();
+ auto realPart = rewriter.getFloatAttr(elementType, 1.0);
+ auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
+ one = rewriter.create<complex::ConstantOp>(
+ loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
+ } else {
one = arith::ConstantOp::create(rewriter, loc,
rewriter.getIntegerAttr(opType, 1));
+ }
// Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
@@ -224,9 +233,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
void mlir::populateMathAlgebraicSimplificationPatterns(
RewritePatternSet &patterns) {
- patterns
- .add<PowFStrengthReduction,
- PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
- PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
- patterns.getContext());
+ patterns.add<
+ PowFStrengthReduction,
+ PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
+ PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
+ PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
+ patterns.getContext(), /*exponentThreshold=*/8);
}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index d37a056e8e158..ff62b515533c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRComplexDialect
MLIRDialectUtils
MLIRIR
MLIRMathDialect
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 080ba4f0ff67b..cf177528e532c 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
return %r : complex<f32>
}
+//CHECK-LABEL: @powi_caller
+//CHECK: (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32)
+func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> {
+ // CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32
+ // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32>
+ // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32>
+ // CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32>
+ // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32>
+ // CHECK: return %[[EXP]] : complex<f32>
+ %r = complex.powi %z, %n : complex<f32>, i32
+ return %r : complex<f32>
+}
+
//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
diff --git a/mlir/test/Dialect/Complex/powi-simplify.mlir b/mlir/test/Dialect/Complex/powi-simplify.mlir
new file mode 100644
index 0000000000000..c7bb6a9d81479
--- /dev/null
+++ b/mlir/test/Dialect/Complex/powi-simplify.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s
+
+func.func @pow3(%arg0: complex<f32>) -> complex<f32> {
+ %c3 = arith.constant 3 : i32
+ %0 = complex.powi %arg0, %c3 : complex<f32>, i32
+ return %0 : complex<f32>
+}
+// CHECK-LABEL: func.func @pow3(
+// CHECK-NOT: complex.powi
+// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32>
+// CHECK: return %[[M1]] : complex<f32>
+
+func.func @pow9(%arg0: complex<f32>) -> complex<f32> {
+ %c9 = arith.constant 9 : i32
+ %0 = complex.powi %arg0, %c9 : complex<f32>, i32
+ return %0 : complex<f32>
+}
+// CHECK-LABEL: func.func @pow9(
+// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
From 8f71488583c15d68c5fd2bf6e86a280698f09624 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Mon, 15 Sep 2025 20:47:37 +0100
Subject: [PATCH 2/6] Fix clang-format.
---
.../ComplexToROCDLLibraryCalls.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 361e422ce1468..dbb26377fc3c4 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -99,9 +99,8 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
Value exponent = rewriter.create<complex::CreateOp>(
loc, op.getLhs().getType(), exponentReal, zeroImag);
- rewriter
- .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
- exponent);
+ rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+ exponent);
return success();
}
};
From 52182f113bde37682749b5b8723a2bd7802300bf Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Mon, 15 Sep 2025 20:57:37 +0100
Subject: [PATCH 3/6] Remove unused function.
---
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 -------
1 file changed, 7 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index d76451459def9..1c251883cf707 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -47,13 +47,6 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
return func;
}
-static bool isZero(Value v) {
- if (auto cst = v.getDefiningOp<arith::ConstantOp>())
- if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
- return attr.getValue().isZero();
- return false;
-}
-
void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
From 6faad70f8ef516996ccd4436e7c7cc3ec29310f6 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Wed, 17 Sep 2025 22:15:26 +0100
Subject: [PATCH 4/6] Add fastmath attribute. Update op description. Update
tests.
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 10 ++--
flang/test/Lower/HLFIR/binary-ops.f90 | 2 +-
.../test/Lower/Intrinsics/pow_complex16i.f90 | 2 +-
.../test/Lower/Intrinsics/pow_complex16k.f90 | 2 +-
flang/test/Transforms/convert-complex-pow.fir | 60 +++++++++----------
.../mlir/Dialect/Complex/IR/ComplexOps.td | 14 +++--
.../ComplexToROCDLLibraryCalls.cpp | 2 +-
.../Transforms/AlgebraicSimplification.cpp | 18 +++++-
8 files changed, 63 insertions(+), 47 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 74a4e8f85c8ff..c7cbf162db786 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1332,9 +1332,11 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
mlir::Value exp = args[1];
mlir::Value result;
- if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
- mlir::isa<mlir::IndexType>(exp.getType())) {
- result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
+ auto fmfAttr = mlir::arith::FastMathFlagsAttr::get(
+ builder.getContext(), builder.getFastMathFlags());
+ if (mlir::isa<mlir::IntegerType>(exp.getType())) {
+ result = builder.create<mlir::complex::PowiOp>(
+ loc, mathLibFuncType.getResult(0), args[0], args[1], fmfAttr);
} else {
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
auto realTy = complexTy.getElementType();
@@ -1343,7 +1345,7 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
zero);
}
- result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
+ result = builder.create<mlir::complex::PowOp>(loc, args[0], exp, fmfAttr);
}
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 7e1691dd1587a..b7695a761a0b8 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
+! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 0b26024b02021..ea18d67b75460 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 90a9f5e03628d..d2b70185bda9f 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index e09fa7316c4b0..23316ed46d40f 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,51 +2,38 @@
module {
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
- %c0 = arith.constant 0.0 : f32
- %0 = fir.convert %arg1 : (i32) -> f32
- %1 = complex.create %0, %c0 : complex<f32>
- %2 = complex.pow %arg0, %1 : complex<f32>
- return %2 : complex<f32>
+ %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
+ return %0 : complex<f32>
+ }
+
+ func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
+ %0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
+ return %0 : complex<f32>
}
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
- %c0 = arith.constant 0.0 : f32
- %0 = fir.convert %arg1 : (i64) -> f32
- %1 = complex.create %0, %c0 : complex<f32>
- %2 = complex.pow %arg0, %1 : complex<f32>
- return %2 : complex<f32>
+ %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
+ return %0 : complex<f32>
}
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
- %c0 = arith.constant 0.0 : f64
- %0 = fir.convert %arg1 : (i32) -> f64
- %1 = complex.create %0, %c0 : complex<f64>
- %2 = complex.pow %arg0, %1 : complex<f64>
- return %2 : complex<f64>
+ %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
+ return %0 : complex<f64>
}
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
- %c0 = arith.constant 0.0 : f64
- %0 = fir.convert %arg1 : (i64) -> f64
- %1 = complex.create %0, %c0 : complex<f64>
- %2 = complex.pow %arg0, %1 : complex<f64>
- return %2 : complex<f64>
+ %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
+ return %0 : complex<f64>
}
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
- %c0 = arith.constant 0.0 : f128
- %0 = fir.convert %arg1 : (i32) -> f128
- %1 = complex.create %0, %c0 : complex<f128>
- %2 = complex.pow %arg0, %1 : complex<f128>
- return %2 : complex<f128>
+ %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
+ return %0 : complex<f128>
}
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
- %c0 = arith.constant 0.0 : f128
- %0 = fir.convert %arg1 : (i64) -> f128
- %1 = complex.create %0, %c0 : complex<f128>
- %2 = complex.pow %arg0, %1 : complex<f128>
- return %2 : complex<f128>
+ %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
+ return %0 : complex<f128>
}
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
@@ -74,26 +61,37 @@ module {
// CHECK-LABEL: func.func @pow_c4_i4(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
+
+// CHECK-LABEL: func.func @pow_c4_i4_fast(
+// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
+// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_i8(
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_i4(
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_i8(
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_i4(
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_i8(
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_fast(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
@@ -108,4 +106,4 @@ module {
// CHECK-LABEL: func.func @pow_c16_complex(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
-// CHECK-NOT: complex.pow
\ No newline at end of file
+// CHECK-NOT: complex.pow
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index ca5103c16889c..828379ded14b3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -449,10 +449,13 @@ def PowOp : ComplexArithmeticOp<"pow"> {
def PowiOp : Complex_Op<"powi",
[Pure, Elementwise, SameOperandsAndResultShape,
- AllTypesMatch<["lhs", "result"]>]> {
- let summary = "complex number raised to integer power";
+ AllTypesMatch<["lhs", "result"]>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ let summary = "complex number raised to signed integer power";
let description = [{
- The `powi` operation takes a complex number and an integer exponent.
+ The `powi` operation takes a `base` operand of complex type and a `power`
+ operand of signed integer type and returns one result of the same type
+ as `base`. The result is `base` raised to the power of `power`.
Example:
@@ -462,11 +465,12 @@ def PowiOp : Complex_Op<"powi",
}];
let arguments = (ins Complex<AnyFloat>:$lhs,
- AnySignlessInteger:$rhs);
+ AnySignlessInteger:$rhs,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath);
let results = (outs Complex<AnyFloat>:$result);
let assemblyFormat =
- "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+ "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index dbb26377fc3c4..42099aaa6b574 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -100,7 +100,7 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
loc, op.getLhs().getType(), exponentReal, zeroImag);
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
- exponent);
+ exponent, op.getFastmathAttr());
return success();
}
};
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 3711c112cc631..fffccf130a571 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -217,13 +217,25 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// `[fi]powi(x, negative_exponent)`
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
+ auto buildMul = [&](Value lhs, Value rhs) {
+ if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
+ return rewriter.create<MulOpTy>(loc, op.getType(), lhs, rhs,
+ op.getFastmathAttr());
+ else
+ return MulOpTy::create(rewriter, loc, lhs, rhs);
+ };
for (unsigned i = 1; i < exponentValue; ++i)
- result = MulOpTy::create(rewriter, loc, result, base);
+ result = buildMul(result, base);
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
- if (exponentIsNegative)
- result = DivOpTy::create(rewriter, loc, bcast(one), result);
+ if (exponentIsNegative) {
+ if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
+ result = rewriter.create<DivOpTy>(loc, op.getType(), bcast(one), result,
+ op.getFastmathAttr());
+ else
+ result = DivOpTy::create(rewriter, loc, bcast(one), result);
+ }
rewriter.replaceOp(op, result);
return success();
>From f659924576565e89b98ad381a8fd54020515592b Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Wed, 17 Sep 2025 23:04:41 +0100
Subject: [PATCH 5/6] Remove genComplexPow, use genMathOp instead. Add
complex.powi->complex.pow conversion in ComplexToStandard pass.
---
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 59 +++++++------------
.../ComplexToStandard/ComplexToStandard.cpp | 25 ++++++++
.../convert-to-standard.mlir | 30 ++++++++++
3 files changed, 76 insertions(+), 38 deletions(-)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index c7cbf162db786..9e7ed8f4d3129 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1272,7 +1272,18 @@ mlir::Value genMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
LLVM_DEBUG(llvm::dbgs() << "Generating '" << mathLibFuncName
<< "' operation with type ";
mathLibFuncType.dump(); llvm::dbgs() << "\n");
- result = T::create(builder, loc, args);
+ if constexpr (std::is_same_v<T, mlir::complex::PowOp>) {
+ auto resultType = mathLibFuncType.getResult(0);
+ result = T::create(builder, loc, resultType, args);
+ } else if constexpr (std::is_same_v<T, mlir::complex::PowiOp>) {
+ auto resultType = mathLibFuncType.getResult(0);
+ auto fmfAttr = mlir::arith::FastMathFlagsAttr::get(
+ builder.getContext(), builder.getFastMathFlags());
+ result = builder.create<mlir::complex::PowiOp>(loc, resultType, args[0],
+ args[1], fmfAttr);
+ } else {
+ result = T::create(builder, loc, args);
+ }
}
LLVM_DEBUG(result.dump(); llvm::dbgs() << "\n");
return result;
@@ -1323,34 +1334,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
-mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
- const MathOperation &mathOp,
- mlir::FunctionType mathLibFuncType,
- llvm::ArrayRef<mlir::Value> args) {
- if (mathRuntimeVersion == preciseVersion)
- return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
- auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
- mlir::Value exp = args[1];
- mlir::Value result;
- auto fmfAttr = mlir::arith::FastMathFlagsAttr::get(
- builder.getContext(), builder.getFastMathFlags());
- if (mlir::isa<mlir::IntegerType>(exp.getType())) {
- result = builder.create<mlir::complex::PowiOp>(
- loc, mathLibFuncType.getResult(0), args[0], args[1], fmfAttr);
- } else {
- if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, exp);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
- zero);
- }
- result = builder.create<mlir::complex::PowOp>(loc, args[0], exp, fmfAttr);
- }
- result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
- return result;
-}
-
/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1676,11 +1659,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
- genComplexPow},
+ genMathOp<mlir::complex::PowOp>},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
- genComplexPow},
+ genMathOp<mlir::complex::PowOp>},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
- genComplexPow},
+ genMathOp<mlir::complex::PowOp>},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
@@ -1701,20 +1684,20 @@ static constexpr MathOperation mathOperations[] = {
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
- genComplexPow},
+ genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(zpowi),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
- genComplexPow},
+ genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
- genComplexPow},
+ genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
- genComplexPow},
+ genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(zpowk),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
- genComplexPow},
+ genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
- genComplexPow},
+ genMathOp<mlir::complex::PowiOp>},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 5ad514d0f48e7..5613e021cd709 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -926,6 +926,30 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
return cutoff4;
}
+struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
+ using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+ auto type = cast<ComplexType>(op.getType());
+ auto elementType = cast<FloatType>(type.getElementType());
+
+ Value floatExponent =
+ builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
+ Value zero = arith::ConstantOp::create(
+ builder, elementType, builder.getFloatAttr(elementType, 0.0));
+ Value complexExponent =
+ complex::CreateOp::create(builder, type, floatExponent, zero);
+
+ auto pow = builder.create<complex::PowOp>(
+ type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
+ rewriter.replaceOp(op, pow.getResult());
+ return success();
+ }
+};
+
struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
using OpConversionPattern<complex::PowOp>::OpConversionPattern;
@@ -1070,6 +1094,7 @@ void mlir::populateComplexToStandardConversionPatterns(
SqrtOpConversion,
TanTanhOpConversion<complex::TanOp>,
TanTanhOpConversion<complex::TanhOp>,
+ PowiOpConversion,
PowOpConversion,
RsqrtOpConversion
>(patterns.getContext());
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index a4ddabbd0821a..dec62f92c7b2e 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -700,6 +700,36 @@ func.func @complex_pow_with_fmf(%lhs: complex<f32>,
// -----
+// CHECK-LABEL: func.func @complex_powi
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
+func.func @complex_powi(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
+ %pow = complex.powi %lhs, %rhs : complex<f32>, i32
+ return %pow : complex<f32>
+}
+
+// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
+// CHECK: math.atan2
+// CHECK-NOT: complex.powi
+
+// -----
+
+// CHECK-LABEL: func.func @complex_powi_with_fmf
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
+func.func @complex_powi_with_fmf(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
+ %pow = complex.powi %lhs, %rhs fastmath<nnan,contract> : complex<f32>, i32
+ return %pow : complex<f32>
+}
+
+// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
+// CHECK: math.atan2 {{.*}} fastmath<nnan,contract> : f32
+// CHECK-NOT: complex.powi
+
+// -----
+
// CHECK-LABEL: func.func @complex_rsqrt
func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
%rsqrt = complex.rsqrt %arg : complex<f32>
>From 9902e0850bcd8c81d0715e966cd1e7307538a748 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 18 Sep 2025 16:19:44 +0100
Subject: [PATCH 6/6] Convert both ops in a single walk.
---
.../Transforms/ConvertComplexPow.cpp | 111 +++++++++---------
1 file changed, 57 insertions(+), 54 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index 1c251883cf707..127f8720ae524 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -51,60 +51,63 @@ void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
- mod.walk([&](complex::PowiOp op) {
- builder.setInsertionPoint(op);
- Location loc = op.getLoc();
- auto complexTy = cast<ComplexType>(op.getType());
- auto elemTy = complexTy.getElementType();
- Value base = op.getLhs();
- Value intExp = op.getRhs();
- func::FuncOp callee;
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
- auto funcTy = builder.getFunctionType(
- {complexTy, builder.getIntegerType(intBits)}, {complexTy});
- if (realBits == 32 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
- else if (realBits == 32 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
- else if (realBits == 64 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
- else if (realBits == 64 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
- else if (realBits == 128 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
- else if (realBits == 128 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
- else
- return;
- auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
- if (auto fmf = op.getFastmathAttr())
- call.setFastmathAttr(fmf);
- op.replaceAllUsesWith(call.getResult(0));
- op.erase();
- });
+ mod.walk([&](Operation *op) {
+ if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
+ builder.setInsertionPoint(powIop);
+ Location loc = powIop.getLoc();
+ auto complexTy = cast<ComplexType>(powIop.getType());
+ auto elemTy = complexTy.getElementType();
+ Value base = powIop.getLhs();
+ Value intExp = powIop.getRhs();
+ func::FuncOp callee;
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+ auto funcTy = builder.getFunctionType(
+ {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+ if (realBits == 32 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+ else if (realBits == 32 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+ else if (realBits == 64 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+ else if (realBits == 64 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+ else if (realBits == 128 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+ else if (realBits == 128 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+ else
+ return;
+ auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+ if (auto fmf = powIop.getFastmathAttr())
+ call.setFastmathAttr(fmf);
+ powIop.replaceAllUsesWith(call.getResult(0));
+ powIop.erase();
+ }
- mod.walk([&](complex::PowOp op) {
- builder.setInsertionPoint(op);
- Location loc = op.getLoc();
- auto complexTy = cast<ComplexType>(op.getType());
- auto elemTy = complexTy.getElementType();
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- func::FuncOp callee;
- auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
- if (realBits == 32)
- callee = getOrDeclare(builder, loc, "cpowf", funcTy);
- else if (realBits == 64)
- callee = getOrDeclare(builder, loc, "cpow", funcTy);
- else if (realBits == 128)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
- else
- return;
- auto call =
- fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
- if (auto fmf = op.getFastmathAttr())
- call.setFastmathAttr(fmf);
- op.replaceAllUsesWith(call.getResult(0));
- op.erase();
+ if (auto powOp = dyn_cast<complex::PowOp>(op)) {
+ builder.setInsertionPoint(powOp);
+ Location loc = powOp.getLoc();
+ auto complexTy = cast<ComplexType>(powOp.getType());
+ auto elemTy = complexTy.getElementType();
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ func::FuncOp callee;
+ auto funcTy =
+ builder.getFunctionType({complexTy, complexTy}, {complexTy});
+ if (realBits == 32)
+ callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+ else if (realBits == 64)
+ callee = getOrDeclare(builder, loc, "cpow", funcTy);
+ else if (realBits == 128)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+ else
+ return;
+ auto call = fir::CallOp::create(builder, loc, callee,
+ {powOp.getLhs(), powOp.getRhs()});
+ if (auto fmf = powOp.getFastmathAttr())
+ call.setFastmathAttr(fmf);
+ powOp.replaceAllUsesWith(call.getResult(0));
+ powOp.erase();
+ }
});
}
More information about the llvm-branch-commits
mailing list