[flang-commits] [flang] [Flang] Add new ConvertComplexPow pass for Flang (PR #158642)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Mon Sep 15 12:59:28 PDT 2025


================
@@ -0,0 +1,124 @@
+//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/static-multimap-view.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+class ConvertComplexPowPass
+    : public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
+                    arith::ArithDialect, func::FuncDialect>();
+  }
+  void runOnOperation() override;
+};
+} // namespace
+
+// Helper to declare or get a math library function.
+static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
+                                 StringRef name, FunctionType type) {
+  if (auto func = builder.getNamedFunction(name))
+    return func;
+  auto func = builder.createFunction(loc, name, type);
+  func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
+  func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
+                builder.getUnitAttr());
+  return func;
+}
+
+static bool isZero(Value v) {
+  if (auto cst = v.getDefiningOp<arith::ConstantOp>())
+    if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
+      return attr.getValue().isZero();
+  return false;
+}
+
+void ConvertComplexPowPass::runOnOperation() {
+  ModuleOp mod = getOperation();
+  if (fir::getTargetTriple(mod).isAMDGCN())
+    return;
+
+  fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
+
+  mod.walk([&](complex::PowOp op) {
+    builder.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto complexTy = cast<ComplexType>(op.getType());
+    auto elemTy = complexTy.getElementType();
+
+    Value base = op.getLhs();
+    Value rhs = op.getRhs();
+
+    Value intExp;
+    if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
+      if (isZero(create.getImaginary())) {
+        if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
+          if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
+            intExp = conv.getValue();
+        }
+      }
+    }
+
+    func::FuncOp callee;
+    SmallVector<Value> args;
+    if (intExp) {
+      unsigned realBits = cast<FloatType>(elemTy).getWidth();
+      unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+      auto funcTy = builder.getFunctionType(
+          {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+      if (realBits == 32 && intBits == 32)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+      else if (realBits == 32 && intBits == 64)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+      else if (realBits == 64 && intBits == 32)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+      else if (realBits == 64 && intBits == 64)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+      else if (realBits == 128 && intBits == 32)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+      else if (realBits == 128 && intBits == 64)
+        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+      else
+        return;
+      args = {base, intExp};
+    } else {
+      unsigned realBits = cast<FloatType>(elemTy).getWidth();
----------------
TIFitis wrote:

> Ideally, we should have `powi` and `powf` operations in the `complex` dialect, so that we do not have to rely on the particular `fir.convert`/`complex.create` pattern generated by the lowering. Moreover, SSA values may become block arguments making it harder to recognize the specific pattern even more.

I have added `powi` in #158722. I'll address the rest of this comment along with other comments tomorrow.



https://github.com/llvm/llvm-project/pull/158642


More information about the flang-commits mailing list