[flang-commits] [flang] [mlir] [MLIR] Add new complex.powi op (PR #158722)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Thu Sep 18 18:11:41 PDT 2025


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158722

>From 78d9190314b84b103ff52eb97e857be4507335c7 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Sat, 13 Sep 2025 01:39:27 +0100
Subject: [PATCH] Add complex.powi op.

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 38 +++--------
 .../Transforms/ConvertComplexPow.cpp          | 66 ++++++++-----------
 flang/test/Lower/HLFIR/binary-ops.f90         |  2 +-
 .../test/Lower/Intrinsics/pow_complex16i.f90  |  2 +-
 .../test/Lower/Intrinsics/pow_complex16k.f90  |  2 +-
 flang/test/Lower/amdgcn-complex.f90           |  9 +++
 flang/test/Lower/power-operator.f90           |  9 ++-
 flang/test/Transforms/convert-complex-pow.fir | 60 ++++++++---------
 .../mlir/Dialect/Complex/IR/ComplexOps.td     | 30 +++++++++
 .../ComplexToROCDLLibraryCalls.cpp            | 40 +++++++++--
 .../ComplexToStandard/ComplexToStandard.cpp   | 25 +++++++
 .../Transforms/AlgebraicSimplification.cpp    | 42 +++++++++---
 .../Dialect/Math/Transforms/CMakeLists.txt    |  1 +
 .../complex-to-rocdl-library-calls.mlir       | 14 ++++
 .../convert-to-standard.mlir                  | 30 +++++++++
 mlir/test/Dialect/Complex/powi-simplify.mlir  | 20 ++++++
 16 files changed, 270 insertions(+), 120 deletions(-)
 create mode 100644 mlir/test/Dialect/Complex/powi-simplify.mlir

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 466458c05dba7..71d35e37bbe94 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1323,26 +1323,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
   return result;
 }
 
-mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
-                          const MathOperation &mathOp,
-                          mlir::FunctionType mathLibFuncType,
-                          llvm::ArrayRef<mlir::Value> args) {
-  if (mathRuntimeVersion == preciseVersion)
-    return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
-  auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
-  mlir::Value exp = args[1];
-  if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
-    auto realTy = complexTy.getElementType();
-    mlir::Value realExp = builder.createConvert(loc, realTy, exp);
-    mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
-    exp =
-        builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
-  }
-  mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
-  result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
-  return result;
-}
-
 /// Mapping between mathematical intrinsic operations and MLIR operations
 /// of some appropriate dialect (math, complex, etc.) or libm calls.
 /// TODO: support remaining Fortran math intrinsics.
@@ -1668,11 +1648,11 @@ static constexpr MathOperation mathOperations[] = {
     {"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
     {"pow", "cpowf",
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
-     genComplexPow},
+     genMathOp<mlir::complex::PowOp>},
     {"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
-     genComplexPow},
+     genMathOp<mlir::complex::PowOp>},
     {"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
-     genComplexPow},
+     genMathOp<mlir::complex::PowOp>},
     {"pow", RTNAME_STRING(FPow4i),
      genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
      genMathOp<mlir::math::FPowIOp>},
@@ -1693,20 +1673,20 @@ static constexpr MathOperation mathOperations[] = {
      genMathOp<mlir::math::FPowIOp>},
     {"pow", RTNAME_STRING(cpowi),
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
-     genComplexPow},
+     genMathOp<mlir::complex::PowiOp>},
     {"pow", RTNAME_STRING(zpowi),
      genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
-     genComplexPow},
+     genMathOp<mlir::complex::PowiOp>},
     {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
-     genComplexPow},
+     genMathOp<mlir::complex::PowiOp>},
     {"pow", RTNAME_STRING(cpowk),
      genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
-     genComplexPow},
+     genMathOp<mlir::complex::PowiOp>},
     {"pow", RTNAME_STRING(zpowk),
      genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
-     genComplexPow},
+     genMathOp<mlir::complex::PowiOp>},
     {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
-     genComplexPow},
+     genMathOp<mlir::complex::PowiOp>},
     {"pow-unsigned", RTNAME_STRING(UPow1),
      genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
     {"pow-unsigned", RTNAME_STRING(UPow2),
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index 78f9d9e4f639a..127f8720ae524 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -47,39 +47,19 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
   return func;
 }
 
-static bool isZero(Value v) {
-  if (auto cst = v.getDefiningOp<arith::ConstantOp>())
-    if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
-      return attr.getValue().isZero();
-  return false;
-}
-
 void ConvertComplexPowPass::runOnOperation() {
   ModuleOp mod = getOperation();
   fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
 
-  mod.walk([&](complex::PowOp op) {
-    builder.setInsertionPoint(op);
-    Location loc = op.getLoc();
-    auto complexTy = cast<ComplexType>(op.getType());
-    auto elemTy = complexTy.getElementType();
-
-    Value base = op.getLhs();
-    Value rhs = op.getRhs();
-
-    Value intExp;
-    if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
-      if (isZero(create.getImaginary())) {
-        if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
-          if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
-            intExp = conv.getValue();
-        }
-      }
-    }
-
-    func::FuncOp callee;
-    SmallVector<Value> args;
-    if (intExp) {
+  mod.walk([&](Operation *op) {
+    if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
+      builder.setInsertionPoint(powIop);
+      Location loc = powIop.getLoc();
+      auto complexTy = cast<ComplexType>(powIop.getType());
+      auto elemTy = complexTy.getElementType();
+      Value base = powIop.getLhs();
+      Value intExp = powIop.getRhs();
+      func::FuncOp callee;
       unsigned realBits = cast<FloatType>(elemTy).getWidth();
       unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
       auto funcTy = builder.getFunctionType(
@@ -98,9 +78,20 @@ void ConvertComplexPowPass::runOnOperation() {
         callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
       else
         return;
-      args = {base, intExp};
-    } else {
+      auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+      if (auto fmf = powIop.getFastmathAttr())
+        call.setFastmathAttr(fmf);
+      powIop.replaceAllUsesWith(call.getResult(0));
+      powIop.erase();
+    }
+
+    if (auto powOp = dyn_cast<complex::PowOp>(op)) {
+      builder.setInsertionPoint(powOp);
+      Location loc = powOp.getLoc();
+      auto complexTy = cast<ComplexType>(powOp.getType());
+      auto elemTy = complexTy.getElementType();
       unsigned realBits = cast<FloatType>(elemTy).getWidth();
+      func::FuncOp callee;
       auto funcTy =
           builder.getFunctionType({complexTy, complexTy}, {complexTy});
       if (realBits == 32)
@@ -111,13 +102,12 @@ void ConvertComplexPowPass::runOnOperation() {
         callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
       else
         return;
-      args = {base, rhs};
+      auto call = fir::CallOp::create(builder, loc, callee,
+                                      {powOp.getLhs(), powOp.getRhs()});
+      if (auto fmf = powOp.getFastmathAttr())
+        call.setFastmathAttr(fmf);
+      powOp.replaceAllUsesWith(call.getResult(0));
+      powOp.erase();
     }
-
-    auto call = fir::CallOp::create(builder, loc, callee, args);
-    if (auto fmf = op.getFastmathAttr())
-      call.setFastmathAttr(fmf);
-    op.replaceAllUsesWith(call.getResult(0));
-    op.erase();
   });
 }
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 1fbd333db37c3..b7695a761a0b8 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK:  %[[VAL_8:.*]] = complex.pow
+! CHECK:  %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32
 
 subroutine extremum(c, n, l)
   integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 1827863a57f43..ea18d67b75460 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
 ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
   complex(16) :: a
   integer(4) :: b
   b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 039dfd5152a06..d2b70185bda9f 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
 ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
   complex(16) :: a
   integer(8) :: b
   b = a ** b
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 4ee5de4d2842e..a28eaea82379b 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
    complex :: a, b, c
    a = b**c
 end subroutine pow_test
+
+! CHECK-LABEL: func @_QPpowi_test(
+! CHECK: complex.powi
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine powi_test(a, b, c)
+   complex :: a, b
+   integer :: i
+   b = a ** i
+end subroutine powi_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 3058927144248..9f74d172a6bb2 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
   complex :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
   ! PRECISE: fir.call @_FortranAcpowi
 end subroutine
 
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
   complex :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
   ! PRECISE: fir.call @_FortranAcpowk
 end subroutine
 
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
   complex(8) :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
   ! PRECISE: fir.call @_FortranAzpowi
 end subroutine
 
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
   complex(8) :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
   ! PRECISE: fir.call @_FortranAzpowk
 end subroutine
 
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
   ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
   ! PRECISE: fir.call @cpow
 end subroutine
-
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index e09fa7316c4b0..23316ed46d40f 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,51 +2,38 @@
 
 module {
   func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
-    %c0 = arith.constant 0.0 : f32
-    %0 = fir.convert %arg1 : (i32) -> f32
-    %1 = complex.create %0, %c0 : complex<f32>
-    %2 = complex.pow %arg0, %1 : complex<f32>
-    return %2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
+    return %0 : complex<f32>
+  }
+
+  func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
+    %0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
+    return %0 : complex<f32>
   }
 
   func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
-    %c0 = arith.constant 0.0 : f32
-    %0 = fir.convert %arg1 : (i64) -> f32
-    %1 = complex.create %0, %c0 : complex<f32>
-    %2 = complex.pow %arg0, %1 : complex<f32>
-    return %2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
+    return %0 : complex<f32>
   }
 
   func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
-    %c0 = arith.constant 0.0 : f64
-    %0 = fir.convert %arg1 : (i32) -> f64
-    %1 = complex.create %0, %c0 : complex<f64>
-    %2 = complex.pow %arg0, %1 : complex<f64>
-    return %2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
+    return %0 : complex<f64>
   }
 
   func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
-    %c0 = arith.constant 0.0 : f64
-    %0 = fir.convert %arg1 : (i64) -> f64
-    %1 = complex.create %0, %c0 : complex<f64>
-    %2 = complex.pow %arg0, %1 : complex<f64>
-    return %2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
+    return %0 : complex<f64>
   }
 
   func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
-    %c0 = arith.constant 0.0 : f128
-    %0 = fir.convert %arg1 : (i32) -> f128
-    %1 = complex.create %0, %c0 : complex<f128>
-    %2 = complex.pow %arg0, %1 : complex<f128>
-    return %2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
+    return %0 : complex<f128>
   }
 
   func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
-    %c0 = arith.constant 0.0 : f128
-    %0 = fir.convert %arg1 : (i64) -> f128
-    %1 = complex.create %0, %c0 : complex<f128>
-    %2 = complex.pow %arg0, %1 : complex<f128>
-    return %2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
+    return %0 : complex<f128>
   }
 
   func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
@@ -74,26 +61,37 @@ module {
 // CHECK-LABEL: func.func @pow_c4_i4(
 // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
 // CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
+
+// CHECK-LABEL: func.func @pow_c4_i4_fast(
+// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
+// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_i8(
 // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
 // CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_i4(
 // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
 // CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_i8(
 // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
 // CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_i4(
 // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
 // CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_i8(
 // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
 // CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_fast(
 // CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
@@ -108,4 +106,4 @@ module {
 // CHECK-LABEL: func.func @pow_c16_complex(
 // CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
 // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
-// CHECK-NOT: complex.pow
\ No newline at end of file
+// CHECK-NOT: complex.pow
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..828379ded14b3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,36 @@ def PowOp : ComplexArithmeticOp<"pow"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+    [Pure, Elementwise, SameOperandsAndResultShape,
+     AllTypesMatch<["lhs", "result"]>,
+     DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "complex number raised to signed integer power";
+  let description = [{
+    The `powi` operation takes a base operand `lhs` of complex type and an
+    exponent operand `rhs` of signless integer type, treated as signed. It
+    returns one result of the same type as `lhs`: `lhs` raised to `rhs`.
+
+    Example:
+
+    ```mlir
+    %a = complex.powi %b, %c : complex<f32>, i32
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$lhs,
+                       AnySignlessInteger:$rhs,
+                       OptionalAttr<Arith_FastMathAttr>:$fastmath);
+  let results = (outs Complex<AnyFloat>:$result);
+
+  let assemblyFormat =
+      "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 72b1fa6e833f9..42099aaa6b574 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -74,10 +76,39 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
     return success();
   }
 };
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+  using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::PowiOp op,
+                                PatternRewriter &rewriter) const final {
+    auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+    Type elementType = complexType.getElementType();
+
+    Type exponentType = op.getRhs().getType();
+    Type exponentFloatType = elementType;
+    if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+      exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+    Location loc = op.getLoc();
+    Value exponentReal =
+        rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+    Value zeroImag = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(exponentFloatType));
+    Value exponent = rewriter.create<complex::CreateOp>(
+        loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+    rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+                                                exponent, op.getFastmathAttr());
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
+  patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
@@ -128,11 +159,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
   populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<func::FuncDialect>();
-  target.addLegalOp<complex::MulOp>();
+  target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+  target.addLegalOp<complex::CreateOp, complex::MulOp>();
   target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
-                      complex::LogOp, complex::PowOp, complex::SinOp,
-                      complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+                      complex::LogOp, complex::PowOp, complex::PowiOp,
+                      complex::SinOp, complex::SqrtOp, complex::TanOp,
+                      complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 5ad514d0f48e7..5613e021cd709 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -926,6 +926,30 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   return cutoff4;
 }
 
+struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
+  using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+    auto type = cast<ComplexType>(op.getType());
+    auto elementType = cast<FloatType>(type.getElementType());
+
+    Value floatExponent =
+        builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
+    Value zero = arith::ConstantOp::create(
+        builder, elementType, builder.getFloatAttr(elementType, 0.0));
+    Value complexExponent =
+        complex::CreateOp::create(builder, type, floatExponent, zero);
+
+    auto pow = builder.create<complex::PowOp>(
+        type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
+    rewriter.replaceOp(op, pow.getResult());
+    return success();
+  }
+};
+
 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
   using OpConversionPattern<complex::PowOp>::OpConversionPattern;
 
@@ -1070,6 +1094,7 @@ void mlir::populateComplexToStandardConversionPatterns(
       SqrtOpConversion,
       TanTanhOpConversion<complex::TanOp>,
       TanTanhOpConversion<complex::TanhOp>,
+      PowiOpConversion,
       PowOpConversion,
       RsqrtOpConversion
   >(patterns.getContext());
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..77b10cec48d8e 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
-  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
     one = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getFloatAttr(opType, 1.0));
-  else
+  } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+    auto complexTy = cast<ComplexType>(opType);
+    Type elementType = complexTy.getElementType();
+    auto realPart = rewriter.getFloatAttr(elementType, 1.0);
+    auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
+    one = complex::ConstantOp::create(
+        rewriter, loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
+  } else {
     one = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getIntegerAttr(opType, 1));
+  }
 
   // Replace `[fi]powi(x, 0)` with `1`.
   if (exponentValue == 0) {
@@ -208,13 +217,25 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   //       `[fi]powi(x, negative_exponent)`
   //     with:
   //       (1 / x) * (1 / x) * (1 / x) * ...
+  auto buildMul = [&](Value lhs, Value rhs) {
+    if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
+      return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
+                             op.getFastmathAttr());
+    else
+      return MulOpTy::create(rewriter, loc, lhs, rhs);
+  };
   for (unsigned i = 1; i < exponentValue; ++i)
-    result = MulOpTy::create(rewriter, loc, result, base);
+    result = buildMul(result, base);
 
   // Inverse the base for negative exponent, i.e. for
   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
-  if (exponentIsNegative)
-    result = DivOpTy::create(rewriter, loc, bcast(one), result);
+  if (exponentIsNegative) {
+    if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
+      result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result,
+                               op.getFastmathAttr());
+    else
+      result = DivOpTy::create(rewriter, loc, bcast(one), result);
+  }
 
   rewriter.replaceOp(op, result);
   return success();
@@ -224,9 +245,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
 
 void mlir::populateMathAlgebraicSimplificationPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<PowFStrengthReduction,
-           PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
-           PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
-          patterns.getContext());
+  patterns.add<
+      PowFStrengthReduction,
+      PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
+      PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
+      PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
+      patterns.getContext(), /*exponentThreshold=*/8);
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index d37a056e8e158..ff62b515533c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRComplexDialect
   MLIRDialectUtils
   MLIRIR
   MLIRMathDialect
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 080ba4f0ff67b..cf177528e532c 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
   return %r : complex<f32>
 }
 
+//CHECK-LABEL: @powi_caller
+//CHECK:          (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32)
+func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> {
+  // CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32
+  // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32>
+  // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32>
+  // CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32>
+  // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32>
+  // CHECK: return %[[EXP]] : complex<f32>
+  %r = complex.powi %z, %n : complex<f32>, i32
+  return %r : complex<f32>
+}
+
 //CHECK-LABEL: @sin_caller
 func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index a4ddabbd0821a..dec62f92c7b2e 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -700,6 +700,36 @@ func.func @complex_pow_with_fmf(%lhs: complex<f32>,
 
 // -----
 
+// CHECK-LABEL: func.func @complex_powi
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
+func.func @complex_powi(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
+  %pow = complex.powi %lhs, %rhs : complex<f32>, i32
+  return %pow : complex<f32>
+}
+
+// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
+// CHECK: math.atan2
+// CHECK-NOT: complex.powi
+
+// -----
+
+// CHECK-LABEL: func.func @complex_powi_with_fmf
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
+func.func @complex_powi_with_fmf(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
+  %pow = complex.powi %lhs, %rhs fastmath<nnan,contract> : complex<f32>, i32
+  return %pow : complex<f32>
+}
+
+// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
+// CHECK: math.atan2 {{.*}} fastmath<nnan,contract> : f32
+// CHECK-NOT: complex.powi
+
+// -----
+
 // CHECK-LABEL:   func.func @complex_rsqrt
 func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
   %rsqrt = complex.rsqrt %arg : complex<f32>
diff --git a/mlir/test/Dialect/Complex/powi-simplify.mlir b/mlir/test/Dialect/Complex/powi-simplify.mlir
new file mode 100644
index 0000000000000..c7bb6a9d81479
--- /dev/null
+++ b/mlir/test/Dialect/Complex/powi-simplify.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s
+
+func.func @pow3(%arg0: complex<f32>) -> complex<f32> {
+  %c3 = arith.constant 3 : i32
+  %0 = complex.powi %arg0, %c3 : complex<f32>, i32
+  return %0 : complex<f32>
+}
+// CHECK-LABEL: func.func @pow3(
+// CHECK-NOT: complex.powi
+// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32>
+// CHECK: return %[[M1]] : complex<f32>
+
+func.func @pow9(%arg0: complex<f32>) -> complex<f32> {
+  %c9 = arith.constant 9 : i32
+  %0 = complex.powi %arg0, %c9 : complex<f32>, i32
+  return %0 : complex<f32>
+}
+// CHECK-LABEL: func.func @pow9(
+// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32



More information about the flang-commits mailing list