[Mlir-commits] [flang] [mlir] [MLIR] Add ComplexToROCDL pass (PR #144926)
Akash Banerjee
llvmlistbot at llvm.org
Fri Jun 20 10:31:43 PDT 2025
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/144926
>From ba1230d545d045623c1ab3e202b7e2b836efe02f Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 19 Jun 2025 17:25:48 +0100
Subject: [PATCH 1/4] [MLIR] Add ComplexToROCDL pass
This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.
---
flang/lib/Optimizer/CodeGen/CMakeLists.txt | 1 +
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 5 +-
.../ComplexToROCDL/ComplexToROCDL.h | 19 ++++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 12 +++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/ComplexToROCDL/CMakeLists.txt | 18 ++++
.../ComplexToROCDL/ComplexToROCDL.cpp | 95 +++++++++++++++++++
.../ComplexToROCDL/complex-to-rocdl.mlir | 13 +++
9 files changed, 164 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
create mode 100644 mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
create mode 100644 mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
+ MLIRComplexToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
- if (isAMDGCN)
+ if (isAMDGCN) {
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+ mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+ }
// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
let dependentDialects = ["func::FuncDialect"];
}
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+ let summary = "Convert Complex dialect to ROCDL calls";
+ let description = [{
+ This pass converts supported Complex ops to calls to the AMD device library.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ComplexToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexCommon)
add_subdirectory(ComplexToLibm)
add_subdirectory(ComplexToLLVM)
+add_subdirectory(ComplexToROCDL)
add_subdirectory(ComplexToSPIRV)
add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+ ComplexToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRComplexDialect
+ MLIRFuncDialect
+ MLIRPass
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+ std::optional<bool> operator()(Type type) const {
+ auto elementType = cast<FloatType>(type);
+ if (!isa<Float32Type, Float64Type>(elementType))
+ return {};
+ return elementType.getIntOrFloatBitWidth() == 64;
+ }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc) {}
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+ auto module = SymbolTable::getNearestSymbolTable(op);
+ auto isDouble = TypeResolver()(op.getType());
+ if (!isDouble.has_value())
+ return failure();
+
+ auto name = *isDouble ? doubleFunc : floatFunc;
+
+ auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+ SymbolTable::lookupSymbolIn(module, name));
+ if (!opFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+ auto funcTy = FunctionType::get(
+ rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+ opFunc =
+ rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+ opFunc.setPrivate();
+ }
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+ op->getOperands());
+ return success();
+ }
+
+private:
+ std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+ patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+ : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+ auto module = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<func::FuncDialect>();
+ target.addIllegalOp<complex::AbsOp>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+ %rf = complex.abs %f : complex<f32>
+ // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+ %rd = complex.abs %d : complex<f64>
+ // CHECK: return %[[RF]], %[[RD]]
+ return %rf, %rd : f32, f64
+}
>From a10e50ad76963f65d3b84ff152159755a7efdcfe Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 19 Jun 2025 17:38:09 +0100
Subject: [PATCH 2/4] Added license to header.
---
.../mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h | 8 ++++++++
mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp | 1 -
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
index ed65be9980408..96d5a352f54c8 100644
--- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -1,3 +1,11 @@
+//===-- ComplexToROCDL.h - conversion from Complex to ROCDL calls ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
index cdfe2a6dfe874..7c7510b5c4e10 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
-
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
>From bb88939bc53d9e88973749c63ba21dfa26279258 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 20 Jun 2025 16:22:07 +0100
Subject: [PATCH 3/4] Address review comments. Add conversion for complex.exp.
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 14 ++---
.../ComplexToROCDL/ComplexToROCDL.h | 3 +-
.../Conversion/ComplexToROCDL/CMakeLists.txt | 3 -
.../ComplexToROCDL/ComplexToROCDL.cpp | 57 ++++++++++---------
.../ComplexToROCDL/complex-to-rocdl.mlir | 19 ++++++-
5 files changed, 54 insertions(+), 42 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f721b6232b0fb..b8c7cba80d863 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4099,7 +4099,7 @@ class FIRToLLVMLowering
// conversions that affect the ModuleOp, e.g. create new
// function operations in it. We have to run such conversions
// as passes here.
- mlir::OpPassManager mathConvertionPM("builtin.module");
+ mlir::OpPassManager mathConversionPM("builtin.module");
bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
// If compiling for AMD target some math operations must be lowered to AMD
@@ -4107,8 +4107,8 @@ class FIRToLLVMLowering
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
if (isAMDGCN) {
- mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
- mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+ mathConversionPM.addPass(mlir::createConvertMathToROCDL());
+ mathConversionPM.addPass(mlir::createConvertComplexToROCDL());
}
// Convert math::FPowI operations to inline implementation
@@ -4116,15 +4116,15 @@ class FIRToLLVMLowering
// it will be lowered to LLVM intrinsic operation by a later conversion.
mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
mathToFuncsOptions.minWidthOfFPowIExponent = 33;
- mathConvertionPM.addPass(
+ mathConversionPM.addPass(
mlir::createConvertMathToFuncs(mathToFuncsOptions));
- mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass());
+ mathConversionPM.addPass(mlir::createConvertComplexToStandardPass());
// Convert Math dialect operations into LLVM dialect operations.
// There is no way to prefer MathToLLVM patterns over MathToLibm
// patterns (applied below), so we have to run MathToLLVM conversion here.
- mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
+ mathConversionPM.addNestedPass<mlir::func::FuncOp>(
mlir::createConvertMathToLLVMPass());
- if (mlir::failed(runPipeline(mathConvertionPM, mod)))
+ if (mlir::failed(runPipeline(mathConversionPM, mod)))
return signalPassFailure();
std::optional<mlir::DataLayout> dl =
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
index 96d5a352f54c8..eb785080adab3 100644
--- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -20,8 +20,7 @@ class RewritePatternSet;
/// Populate the given list with patterns that convert from Complex to ROCDL
/// calls.
-void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit);
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
index 54607250083d7..133809ac32f0f 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -7,9 +7,6 @@ add_mlir_conversion_library(MLIRComplexToROCDL
DEPENDS
MLIRConversionPassIncGen
- LINK_COMPONENTS
- Core
-
LINK_LIBS PUBLIC
MLIRComplexDialect
MLIRFuncDialect
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
index 7c7510b5c4e10..98adb9fb1f607 100644
--- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -11,7 +11,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
@@ -21,36 +20,38 @@ namespace mlir {
using namespace mlir;
namespace {
-struct FloatTypeResolver {
- std::optional<bool> operator()(Type type) const {
- auto elementType = cast<FloatType>(type);
- if (!isa<Float32Type, Float64Type>(elementType))
- return {};
- return elementType.getIntOrFloatBitWidth() == 64;
- }
-};
-template <typename Op, typename TypeResolver = FloatTypeResolver>
-struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+template <typename Op>
+// Pattern to convert Complex ops to ROCDL function calls.
+struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;
- ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
- StringRef doubleFunc, PatternBenefit benefit)
+ ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit = 1)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc) {}
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
- auto module = SymbolTable::getNearestSymbolTable(op);
- auto isDouble = TypeResolver()(op.getType());
- if (!isDouble.has_value())
+ Operation *symTable = SymbolTable::getNearestSymbolTable(op);
+ Type resType = op.getType();
+ if (auto complexType = dyn_cast<ComplexType>(resType))
+ resType = complexType.getElementType();
+ FloatType floatTy = dyn_cast<FloatType>(resType);
+ if (!floatTy)
return failure();
- auto name = *isDouble ? doubleFunc : floatFunc;
+ StringRef name;
+ if (floatTy.isF64())
+ name = doubleFunc;
+ else if (floatTy.isF32())
+ name = floatFunc;
+ else
+ return failure();
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
- SymbolTable::lookupSymbolIn(module, name));
+ SymbolTable::lookupSymbolIn(symTable, name));
if (!opFunc) {
OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+ rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
auto funcTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
opFunc =
@@ -67,10 +68,12 @@ struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
};
} // namespace
-void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit) {
- patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
- patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+void mlir::populateComplexToROCDLConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ComplexOpToROCDLCall<complex::AbsOp>>(
+ patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLCall<complex::ExpOp>>(
+ patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64");
}
namespace {
@@ -81,14 +84,23 @@ struct ConvertComplexToROCDLPass
} // namespace
void ConvertComplexToROCDLPass::runOnOperation() {
- auto module = getOperation();
+ Operation *op = getOperation();
RewritePatternSet patterns(&getContext());
- populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+ populateComplexToROCDLConversionPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp>();
- if (failed(applyPartialConversion(module, target, std::move(patterns))))
+  // Only f32/f64 element types have OCML counterparts (the patterns bail out
+  // on anything else), so ops on other element types such as f16/bf16 are
+  // left legal for a later lowering (e.g. ComplexToStandard) instead of
+  // making this partial conversion fail.
+  target.addDynamicallyLegalOp<complex::AbsOp, complex::ExpOp>(
+      [](Operation *complexOp) {
+        Type elemTy = cast<ComplexType>(complexOp->getOperand(0).getType())
+                          .getElementType();
+        return !elemTy.isF32() && !elemTy.isF64();
+      });
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
index 618e9c238378c..23631e25e4588 100644
--- a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -1,13 +1,26 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -convert-complex-to-rocdl | FileCheck %s
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+//CHECK-LABEL: @abs_caller
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
- // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+ // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}})
%rf = complex.abs %f : complex<f32>
- // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+ // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}})
%rd = complex.abs %d : complex<f64>
// CHECK: return %[[RF]], %[[RD]]
return %rf, %rd : f32, f64
}
+
+//CHECK-LABEL: @exp_caller
+func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
+ %ef = complex.exp %f : complex<f32>
+ // CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}})
+ %ed = complex.exp %d : complex<f64>
+ // CHECK: return %[[EF]], %[[ED]]
+ return %ef, %ed : complex<f32>, complex<f64>
+}
>From d70bca22bae1d66f126b5a1832153810baf06ab4 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 20 Jun 2025 18:30:57 +0100
Subject: [PATCH 4/4] Correct alphabetical order for cmake.
---
flang/lib/Optimizer/CodeGen/CMakeLists.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 8b4ac18fba527..de8e1c5c3fa3f 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -34,13 +34,13 @@ add_flang_library(FIRCodeGen
MLIR_LIBS
MLIRComplexToLLVM
+ MLIRComplexToROCDL
MLIRComplexToStandard
MLIRGPUDialect
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
- MLIRComplexToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
More information about the Mlir-commits
mailing list