[flang-commits] [flang] [mlir] [MLIR] Add ComplexToROCDL pass (PR #144926)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Wed Jul 2 07:35:41 PDT 2025


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/144926

>From ba1230d545d045623c1ab3e202b7e2b836efe02f Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 19 Jun 2025 17:25:48 +0100
Subject: [PATCH 1/6] [MLIR] Add ComplexToROCDL pass

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.
---
 flang/lib/Optimizer/CodeGen/CMakeLists.txt    |  1 +
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  5 +-
 .../ComplexToROCDL/ComplexToROCDL.h           | 19 ++++
 mlir/include/mlir/Conversion/Passes.h         |  1 +
 mlir/include/mlir/Conversion/Passes.td        | 12 +++
 mlir/lib/Conversion/CMakeLists.txt            |  1 +
 .../Conversion/ComplexToROCDL/CMakeLists.txt  | 18 ++++
 .../ComplexToROCDL/ComplexToROCDL.cpp         | 95 +++++++++++++++++++
 .../ComplexToROCDL/complex-to-rocdl.mlir      | 13 +++
 9 files changed, 164 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
 create mode 100644 mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
 create mode 100644 mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir

diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIRMathToROCDL
+  MLIRComplexToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
-    if (isAMDGCN)
+    if (isAMDGCN) {
       mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+    }
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
   let dependentDialects = ["func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL calls";
+  let description = [{
+    This pass converts supported Complex ops to calls to the AMD device library.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+  ComplexToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+  std::optional<bool> operator()(Type type) const {
+    auto elementType = cast<FloatType>(type);
+    if (!isa<Float32Type, Float64Type>(elementType))
+      return {};
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+                      StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    auto module = SymbolTable::getNearestSymbolTable(op);
+    auto isDouble = TypeResolver()(op.getType());
+    if (!isDouble.has_value())
+      return failure();
+
+    auto name = *isDouble ? doubleFunc : floatFunc;
+
+    auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+        SymbolTable::lookupSymbolIn(module, name));
+    if (!opFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+      auto funcTy = FunctionType::get(
+          rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+      opFunc =
+          rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+      opFunc.setPrivate();
+    }
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+    : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  target.addIllegalOp<complex::AbsOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+  %rf = complex.abs %f : complex<f32>
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+  %rd = complex.abs %d : complex<f64>
+  // CHECK: return %[[RF]], %[[RD]]
+  return %rf, %rd : f32, f64
+}

>From a10e50ad76963f65d3b84ff152159755a7efdcfe Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 19 Jun 2025 17:38:09 +0100
Subject: [PATCH 2/6] Add license header to ComplexToROCDL.h.

---
 .../mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h       | 8 ++++++++
 mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp     | 1 -
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
index ed65be9980408..96d5a352f54c8 100644
--- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -1,3 +1,11 @@
+//===-- ComplexToROCDL.h - conversion from Complex to ROCDL calls ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
 #define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
 
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
index cdfe2a6dfe874..7c7510b5c4e10 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
-
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"

>From bb88939bc53d9e88973749c63ba21dfa26279258 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 20 Jun 2025 16:22:07 +0100
Subject: [PATCH 3/6] Address review comments. Add conversion for complex.exp.

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 14 ++---
 .../ComplexToROCDL/ComplexToROCDL.h           |  3 +-
 .../Conversion/ComplexToROCDL/CMakeLists.txt  |  3 -
 .../ComplexToROCDL/ComplexToROCDL.cpp         | 57 ++++++++++---------
 .../ComplexToROCDL/complex-to-rocdl.mlir      | 19 ++++++-
 5 files changed, 54 insertions(+), 42 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f721b6232b0fb..b8c7cba80d863 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4099,7 +4099,7 @@ class FIRToLLVMLowering
     // conversions that affect the ModuleOp, e.g. create new
     // function operations in it. We have to run such conversions
     // as passes here.
-    mlir::OpPassManager mathConvertionPM("builtin.module");
+    mlir::OpPassManager mathConversionPM("builtin.module");
 
     bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
     // If compiling for AMD target some math operations must be lowered to AMD
@@ -4107,8 +4107,8 @@ class FIRToLLVMLowering
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
     if (isAMDGCN) {
-      mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
-      mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+      mathConversionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConversionPM.addPass(mlir::createConvertComplexToROCDL());
     }
 
     // Convert math::FPowI operations to inline implementation
@@ -4116,15 +4116,15 @@ class FIRToLLVMLowering
     // it will be lowered to LLVM intrinsic operation by a later conversion.
     mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
     mathToFuncsOptions.minWidthOfFPowIExponent = 33;
-    mathConvertionPM.addPass(
+    mathConversionPM.addPass(
         mlir::createConvertMathToFuncs(mathToFuncsOptions));
-    mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass());
+    mathConversionPM.addPass(mlir::createConvertComplexToStandardPass());
     // Convert Math dialect operations into LLVM dialect operations.
     // There is no way to prefer MathToLLVM patterns over MathToLibm
     // patterns (applied below), so we have to run MathToLLVM conversion here.
-    mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
+    mathConversionPM.addNestedPass<mlir::func::FuncOp>(
         mlir::createConvertMathToLLVMPass());
-    if (mlir::failed(runPipeline(mathConvertionPM, mod)))
+    if (mlir::failed(runPipeline(mathConversionPM, mod)))
       return signalPassFailure();
 
     std::optional<mlir::DataLayout> dl =
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
index 96d5a352f54c8..eb785080adab3 100644
--- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -20,8 +20,7 @@ class RewritePatternSet;
 
 /// Populate the given list with patterns that convert from Complex to ROCDL
 /// calls.
-void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit);
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
index 54607250083d7..133809ac32f0f 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -7,9 +7,6 @@ add_mlir_conversion_library(MLIRComplexToROCDL
   DEPENDS
   MLIRConversionPassIncGen
 
-  LINK_COMPONENTS
-  Core
-
   LINK_LIBS PUBLIC
   MLIRComplexDialect
   MLIRFuncDialect
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
index 7c7510b5c4e10..98adb9fb1f607 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -11,7 +11,6 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include <optional>
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
@@ -21,36 +20,38 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
-struct FloatTypeResolver {
-  std::optional<bool> operator()(Type type) const {
-    auto elementType = cast<FloatType>(type);
-    if (!isa<Float32Type, Float64Type>(elementType))
-      return {};
-    return elementType.getIntOrFloatBitWidth() == 64;
-  }
-};
 
-template <typename Op, typename TypeResolver = FloatTypeResolver>
-struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+template <typename Op>
+// Pattern to convert Complex ops to ROCDL function calls.
+struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
   using OpRewritePattern<Op>::OpRewritePattern;
-  ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
-                      StringRef doubleFunc, PatternBenefit benefit)
+  ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+                       StringRef doubleFunc, PatternBenefit benefit = 1)
       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
         doubleFunc(doubleFunc) {}
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
-    auto module = SymbolTable::getNearestSymbolTable(op);
-    auto isDouble = TypeResolver()(op.getType());
-    if (!isDouble.has_value())
+    Operation *symTable = SymbolTable::getNearestSymbolTable(op);
+    Type resType = op.getType();
+    if (auto complexType = dyn_cast<ComplexType>(resType))
+      resType = complexType.getElementType();
+    FloatType floatTy = dyn_cast<FloatType>(resType);
+    if (!floatTy)
       return failure();
 
-    auto name = *isDouble ? doubleFunc : floatFunc;
+    StringRef name;
+    if (floatTy.isF64())
+      name = doubleFunc;
+    else if (floatTy.isF32())
+      name = floatFunc;
+    else
+      return failure();
 
     auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
-        SymbolTable::lookupSymbolIn(module, name));
+        SymbolTable::lookupSymbolIn(symTable, name));
     if (!opFunc) {
       OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+      rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
       auto funcTy = FunctionType::get(
           rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
       opFunc =
@@ -67,10 +68,12 @@ struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
 };
 } // namespace
 
-void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
-                                                    PatternBenefit benefit) {
-  patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
-      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+void mlir::populateComplexToROCDLConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ComplexOpToROCDLCall<complex::AbsOp>>(
+      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64");
+  patterns.add<ComplexOpToROCDLCall<complex::ExpOp>>(
+      patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64");
 }
 
 namespace {
@@ -81,14 +84,14 @@ struct ConvertComplexToROCDLPass
 } // namespace
 
 void ConvertComplexToROCDLPass::runOnOperation() {
-  auto module = getOperation();
+  Operation *op = getOperation();
 
   RewritePatternSet patterns(&getContext());
-  populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+  populateComplexToROCDLConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
-  target.addIllegalOp<complex::AbsOp>();
-  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+  target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
index 618e9c238378c..23631e25e4588 100644
--- a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -1,13 +1,26 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -convert-complex-to-rocdl | FileCheck %s
 
 // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
 
+//CHECK-LABEL: @abs_caller
 func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
-  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}})
   %rf = complex.abs %f : complex<f32>
-  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}})
   %rd = complex.abs %d : complex<f64>
   // CHECK: return %[[RF]], %[[RD]]
   return %rf, %rd : f32, f64
 }
+
+//CHECK-LABEL: @exp_caller
+func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
+  %ef = complex.exp %f : complex<f32>
+  // CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}})
+  %ed = complex.exp %d : complex<f64>
+  // CHECK: return %[[EF]], %[[ED]]
+  return %ef, %ed : complex<f32>, complex<f64>
+}

>From d70bca22bae1d66f126b5a1832153810baf06ab4 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 20 Jun 2025 18:30:57 +0100
Subject: [PATCH 4/6] Fix alphabetical order of MLIR libraries in CMakeLists.txt.

---
 flang/lib/Optimizer/CodeGen/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 8b4ac18fba527..de8e1c5c3fa3f 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -34,13 +34,13 @@ add_flang_library(FIRCodeGen
 
   MLIR_LIBS
   MLIRComplexToLLVM
+  MLIRComplexToROCDL
   MLIRComplexToStandard
   MLIRGPUDialect
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIRMathToROCDL
-  MLIRComplexToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation

>From 0e9aef2910fee55664f7bbfd54ec7bbb1ee383f3 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 24 Jun 2025 16:34:11 +0100
Subject: [PATCH 5/6] Make the float element type a template parameter.

---
 .../ComplexToROCDL/ComplexToROCDL.cpp         | 42 ++++++++-----------
 1 file changed, 18 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
index 98adb9fb1f607..cfad9f5f6fa19 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -21,59 +21,53 @@ using namespace mlir;
 
 namespace {
 
-template <typename Op>
+template <typename Op, typename Ty>
 // Pattern to convert Complex ops to ROCDL function calls.
 struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
   using OpRewritePattern<Op>::OpRewritePattern;
-  ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
-                       StringRef doubleFunc, PatternBenefit benefit = 1)
-      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
-        doubleFunc(doubleFunc) {}
+  ComplexOpToROCDLCall(MLIRContext *context, StringRef funcName,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
     Operation *symTable = SymbolTable::getNearestSymbolTable(op);
     Type resType = op.getType();
     if (auto complexType = dyn_cast<ComplexType>(resType))
       resType = complexType.getElementType();
-    FloatType floatTy = dyn_cast<FloatType>(resType);
-    if (!floatTy)
-      return failure();
-
-    StringRef name;
-    if (floatTy.isF64())
-      name = doubleFunc;
-    else if (floatTy.isF32())
-      name = floatFunc;
-    else
+    if (!isa<Ty>(resType))
       return failure();
 
     auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
-        SymbolTable::lookupSymbolIn(symTable, name));
+        SymbolTable::lookupSymbolIn(symTable, funcName));
     if (!opFunc) {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
       auto funcTy = FunctionType::get(
           rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
-      opFunc =
-          rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+      opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
+                                             funcTy);
       opFunc.setPrivate();
     }
-    rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
                                               op->getOperands());
     return success();
   }
 
 private:
-  std::string floatFunc, doubleFunc;
+  std::string funcName;
 };
 } // namespace
 
 void mlir::populateComplexToROCDLConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ComplexOpToROCDLCall<complex::AbsOp>>(
-      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64");
-  patterns.add<ComplexOpToROCDLCall<complex::ExpOp>>(
-      patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64");
+  patterns.add<ComplexOpToROCDLCall<complex::AbsOp, Float32Type>>(
+      patterns.getContext(), "__ocml_cabs_f32");
+  patterns.add<ComplexOpToROCDLCall<complex::AbsOp, Float64Type>>(
+      patterns.getContext(), "__ocml_cabs_f64");
+  patterns.add<ComplexOpToROCDLCall<complex::ExpOp, Float32Type>>(
+      patterns.getContext(), "__ocml_cexp_f32");
+  patterns.add<ComplexOpToROCDLCall<complex::ExpOp, Float64Type>>(
+      patterns.getContext(), "__ocml_cexp_f64");
 }
 
 namespace {

>From 03e439e94408ae4f6dafe0cbc8f8dfad6a80dc54 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Wed, 2 Jul 2025 14:48:34 +0100
Subject: [PATCH 6/6] Rename to ComplexToROCDLLibraryCalls

---
 flang/lib/Optimizer/CodeGen/CMakeLists.txt    |  2 +-
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  4 +--
 .../ComplexToROCDLLibraryCalls.h}             | 14 ++++----
 mlir/include/mlir/Conversion/Passes.h         |  2 +-
 mlir/include/mlir/Conversion/Passes.td        |  6 ++--
 mlir/lib/Conversion/CMakeLists.txt            |  2 +-
 .../Conversion/ComplexToROCDL/CMakeLists.txt  | 15 --------
 .../ComplexToROCDLLibraryCalls/CMakeLists.txt | 18 ++++++++++
 .../ComplexToROCDLLibraryCalls.cpp}           | 35 ++++++++++---------
 .../complex-to-rocdl-library-calls.mlir}      |  2 +-
 10 files changed, 52 insertions(+), 48 deletions(-)
 rename mlir/include/mlir/Conversion/{ComplexToROCDL/ComplexToROCDL.h => ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h} (52%)
 delete mode 100644 mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
 rename mlir/lib/Conversion/{ComplexToROCDL/ComplexToROCDL.cpp => ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp} (68%)
 rename mlir/test/Conversion/{ComplexToROCDL/complex-to-rocdl.mlir => ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir} (92%)

diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index de8e1c5c3fa3f..16c7944a885a1 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -34,7 +34,7 @@ add_flang_library(FIRCodeGen
 
   MLIR_LIBS
   MLIRComplexToLLVM
-  MLIRComplexToROCDL
+  MLIRComplexToROCDLLibraryCalls
   MLIRComplexToStandard
   MLIRGPUDialect
   MLIRMathToFuncs
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index b8c7cba80d863..a536b3f13e997 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,7 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
-#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4108,7 +4108,7 @@ class FIRToLLVMLowering
     // not needed since all math operations are handled this way.
     if (isAMDGCN) {
       mathConversionPM.addPass(mlir::createConvertMathToROCDL());
-      mathConversionPM.addPass(mlir::createConvertComplexToROCDL());
+      mathConversionPM.addPass(mlir::createConvertComplexToROCDLLibraryCalls());
     }
 
     // Convert math::FPowI operations to inline implementation
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
similarity index 52%
rename from mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
rename to mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
index eb785080adab3..409b87747f437 100644
--- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
+++ b/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
@@ -1,4 +1,4 @@
-//===-- ComplexToROCDL.h - conversion from Complex to ROCDL calls ---------===//
+//===- ComplexToROCDLLibraryCalls.h - convert from Complex to ROCDL calls -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,21 +6,21 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
-#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
 
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
 
 namespace mlir {
 class RewritePatternSet;
 
-#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Populate the given list with patterns that convert from Complex to ROCDL
 /// calls.
-void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns);
+void populateComplexToROCDLLibraryCallsConversionPatterns(
+    RewritePatternSet &patterns);
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#endif // MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 67e8f5b99b67b..896fff277679e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,7 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
-#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8ad2341f93a15..b85afb0ae97be 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -313,11 +313,11 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
 }
 
 //===----------------------------------------------------------------------===//
-// ComplexToROCDL
+// ComplexToROCDLLibraryCalls
 //===----------------------------------------------------------------------===//
 
-def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
-  let summary = "Convert Complex dialect to ROCDL calls";
+def ConvertComplexToROCDLLibraryCalls : Pass<"convert-complex-to-rocdl-library-calls", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL library calls";
   let description = [{
     This pass converts supported Complex ops to calls to the AMD device library.
   }];
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 4ad81553a4fa8..b731662dcebc7 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,7 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
-add_subdirectory(ComplexToROCDL)
+add_subdirectory(ComplexToROCDLLibraryCalls)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
deleted file mode 100644
index 133809ac32f0f..0000000000000
--- a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-add_mlir_conversion_library(MLIRComplexToROCDL
-  ComplexToROCDL.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRComplexDialect
-  MLIRFuncDialect
-  MLIRPass
-  MLIRTransformUtils
-  )
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
new file mode 100644
index 0000000000000..695bb2dd0a82c
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDLLibraryCalls
+  ComplexToROCDLLibraryCalls.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDLLibraryCalls
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
similarity index 68%
rename from mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
rename to mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index cfad9f5f6fa19..99d5424aef79a 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -1,4 +1,4 @@
-//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//===- ComplexToROCDLLibraryCalls.cpp - Complex to ROCDL library calls ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
@@ -21,12 +21,12 @@ using namespace mlir;
 
 namespace {
 
-template <typename Op, typename Ty>
+template <typename Op, typename FloatTy>
 // Pattern to convert Complex ops to ROCDL function calls.
-struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
+struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
   using OpRewritePattern<Op>::OpRewritePattern;
-  ComplexOpToROCDLCall(MLIRContext *context, StringRef funcName,
-                       PatternBenefit benefit = 1)
+  ComplexOpToROCDLLibraryCalls(MLIRContext *context, StringRef funcName,
+                               PatternBenefit benefit = 1)
       : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
@@ -34,7 +34,7 @@ struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
     Type resType = op.getType();
     if (auto complexType = dyn_cast<ComplexType>(resType))
       resType = complexType.getElementType();
-    if (!isa<Ty>(resType))
+    if (!isa<FloatTy>(resType))
       return failure();
 
     auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
@@ -58,30 +58,31 @@ struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
 };
 } // namespace
 
-void mlir::populateComplexToROCDLConversionPatterns(
+void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ComplexOpToROCDLCall<complex::AbsOp, Float32Type>>(
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
-  patterns.add<ComplexOpToROCDLCall<complex::AbsOp, Float64Type>>(
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
       patterns.getContext(), "__ocml_cabs_f64");
-  patterns.add<ComplexOpToROCDLCall<complex::ExpOp, Float32Type>>(
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
       patterns.getContext(), "__ocml_cexp_f32");
-  patterns.add<ComplexOpToROCDLCall<complex::ExpOp, Float64Type>>(
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
       patterns.getContext(), "__ocml_cexp_f64");
 }
 
 namespace {
-struct ConvertComplexToROCDLPass
-    : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+struct ConvertComplexToROCDLLibraryCallsPass
+    : public impl::ConvertComplexToROCDLLibraryCallsBase<
+          ConvertComplexToROCDLLibraryCallsPass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void ConvertComplexToROCDLPass::runOnOperation() {
+void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
   Operation *op = getOperation();
 
   RewritePatternSet patterns(&getContext());
-  populateComplexToROCDLConversionPatterns(patterns);
+  populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
similarity index 92%
rename from mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
rename to mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 23631e25e4588..bae7c5986ef9e 100644
--- a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl | FileCheck %s
+// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
 
 // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64



More information about the flang-commits mailing list