[Mlir-commits] [mlir] fc114e4 - [MLIR] Add ComplexToROCDLLibraryCalls pass (#144926)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 16 05:59:44 PDT 2025
Author: Akash Banerjee
Date: 2025-07-16T13:59:41+01:00
New Revision: fc114e4d931ae25f74a15e42371dbead1387ad51
URL: https://github.com/llvm/llvm-project/commit/fc114e4d931ae25f74a15e42371dbead1387ad51
DIFF: https://github.com/llvm/llvm-project/commit/fc114e4d931ae25f74a15e42371dbead1387ad51.diff
LOG: [MLIR] Add ComplexToROCDLLibraryCalls pass (#144926)
Added:
mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
Modified:
flang/lib/Optimizer/CodeGen/CMakeLists.txt
flang/lib/Optimizer/CodeGen/CodeGen.cpp
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..16c7944a885a1 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -34,6 +34,7 @@ add_flang_library(FIRCodeGen
MLIR_LIBS
MLIRComplexToLLVM
+ MLIRComplexToROCDLLibraryCalls
MLIRComplexToStandard
MLIRGPUDialect
MLIRMathToFuncs
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ecc04a6c9a2be..5ca53ee48955e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4145,22 +4146,24 @@ class FIRToLLVMLowering
// conversions that affect the ModuleOp, e.g. create new
// function operations in it. We have to run such conversions
// as passes here.
- mlir::OpPassManager mathConvertionPM("builtin.module");
+ mlir::OpPassManager mathConversionPM("builtin.module");
bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
// If compiling for AMD target some math operations must be lowered to AMD
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
- if (isAMDGCN)
- mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+ if (isAMDGCN) {
+ mathConversionPM.addPass(mlir::createConvertMathToROCDL());
+ mathConversionPM.addPass(mlir::createConvertComplexToROCDLLibraryCalls());
+ }
// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
// it will be lowered to LLVM intrinsic operation by a later conversion.
mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
mathToFuncsOptions.minWidthOfFPowIExponent = 33;
- mathConvertionPM.addPass(
+ mathConversionPM.addPass(
mlir::createConvertMathToFuncs(mathToFuncsOptions));
mlir::ConvertComplexToStandardPassOptions complexToStandardOptions{};
@@ -4173,15 +4176,15 @@ class FIRToLLVMLowering
complexToStandardOptions.complexRange =
mlir::complex::ComplexRangeFlags::improved;
}
- mathConvertionPM.addPass(
+ mathConversionPM.addPass(
mlir::createConvertComplexToStandardPass(complexToStandardOptions));
// Convert Math dialect operations into LLVM dialect operations.
// There is no way to prefer MathToLLVM patterns over MathToLibm
// patterns (applied below), so we have to run MathToLLVM conversion here.
- mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
+ mathConversionPM.addNestedPass<mlir::func::FuncOp>(
mlir::createConvertMathToLLVMPass());
- if (mlir::failed(runPipeline(mathConvertionPM, mod)))
+ if (mlir::failed(runPipeline(mathConversionPM, mod)))
return signalPassFailure();
std::optional<mlir::DataLayout> dl =
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h b/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
new file mode 100644
index 0000000000000..daac2a99ed80f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
@@ -0,0 +1,27 @@
+//===- ComplexToROCDLLibraryCalls.h - convert from Complex to ROCDL calls -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+// Tablegen-generated declarations for the pass defined in Passes.td
+// (presumably including createConvertComplexToROCDLLibraryCalls(), which
+// flang's CodeGen.cpp calls).
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+///
+/// Each pattern replaces a supported Complex op with a `func.call` to an AMD
+/// device-library function (currently `complex.abs` -> `__ocml_cabs_f32/f64`
+/// and `complex.exp` -> `__ocml_cexp_f32/f64`, selected by element type). If
+/// the callee is not yet present in the nearest symbol table, a private
+/// `func.func` declaration for it is inserted there.
+void populateComplexToROCDLLibraryCallsConversionPatterns(
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8a5976e547169..d93fbefab74aa 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..76e751243a12c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
let dependentDialects = ["func::FuncDialect"];
}
+//===----------------------------------------------------------------------===//
+// ComplexToROCDLLibraryCalls
+//===----------------------------------------------------------------------===//
+
+// Currently only complex.abs and complex.exp with f32/f64 element types are
+// rewritten (to __ocml_cabs_* / __ocml_cexp_*); see
+// populateComplexToROCDLLibraryCallsConversionPatterns for the full list.
+def ConvertComplexToROCDLLibraryCalls : Pass<"convert-complex-to-rocdl-library-calls", "ModuleOp"> {
+ let summary = "Convert Complex dialect to ROCDL library calls";
+ let description = [{
+ This pass converts supported Complex ops to calls to the AMD device library.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ComplexToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 24a48993ad80c..f84375b6b8d6a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexCommon)
add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDLLibraryCalls)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToSPIRV)
add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
new file mode 100644
index 0000000000000..695bb2dd0a82c
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDLLibraryCalls
+ ComplexToROCDLLibraryCalls.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDLLibraryCalls
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRComplexDialect
+ MLIRFuncDialect
+ MLIRPass
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
new file mode 100644
index 0000000000000..99d5424aef79a
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -0,0 +1,92 @@
+//=== ComplexToROCDLLibraryCalls.cpp - convert from Complex to ROCDL calls ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Rewrites `Op` into a `func.call` to the device-library function `funcName`
+/// when the op's result type -- or, for a `complex<T>` result, its element
+/// type `T` -- is `FloatTy`. If no symbol named `funcName` exists in the
+/// nearest symbol table, a private `func.func` declaration whose signature
+/// mirrors the op's operand/result types is inserted at the start of it.
+template <typename Op, typename FloatTy>
+struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  ComplexOpToROCDLLibraryCalls(MLIRContext *context, StringRef funcName,
+                               PatternBenefit benefit = 1)
+      : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    if (!hasMatchingFloatType(op))
+      return failure();
+
+    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+    if (!dyn_cast_or_null<SymbolOpInterface>(
+            SymbolTable::lookupSymbolIn(symbolTableOp, funcName)))
+      declareCallee(op, symbolTableOp, rewriter);
+
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  /// True if `op`'s result, looking through `complex<T>` to `T`, is `FloatTy`.
+  static bool hasMatchingFloatType(Op op) {
+    Type resultType = op.getType();
+    if (auto complexType = dyn_cast<ComplexType>(resultType))
+      return isa<FloatTy>(complexType.getElementType());
+    return isa<FloatTy>(resultType);
+  }
+
+  /// Inserts `func.func private @funcName` with `op`'s operand/result types at
+  /// the start of `symbolTableOp`'s first block; the guard restores the
+  /// rewriter's insertion point on return.
+  void declareCallee(Op op, Operation *symbolTableOp,
+                     PatternRewriter &rewriter) const {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&symbolTableOp->getRegion(0).front());
+    auto calleeType = FunctionType::get(
+        rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+    auto callee = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(),
+                                                funcName, calleeType);
+    callee.setPrivate();
+  }
+
+  std::string funcName;
+};
+} // namespace
+
+/// Registers one pattern per (op, element type) pair, mapping complex.abs and
+/// complex.exp on f32/f64 to the corresponding OCML entry points.
+void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns
+      // complex.abs : complex<T> -> T
+      .add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
+          context, "__ocml_cabs_f32")
+      .add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
+          context, "__ocml_cabs_f64")
+      // complex.exp : complex<T> -> complex<T>
+      .add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
+          context, "__ocml_cexp_f32")
+      .add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
+          context, "__ocml_cexp_f64");
+}
+
+namespace {
+/// ModuleOp pass that applies the patterns from
+/// populateComplexToROCDLLibraryCallsConversionPatterns as a partial
+/// conversion.
+struct ConvertComplexToROCDLLibraryCallsPass
+    : public impl::ConvertComplexToROCDLLibraryCallsBase<
+          ConvertComplexToROCDLLibraryCallsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Rewrites complex.abs / complex.exp on f32/f64 into device-library calls.
+/// The same ops on any other element type are left in place for later
+/// lowerings (in flang, ComplexToStandard runs afterwards) instead of
+/// failing the pass.
+void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  // Library patterns exist only for f32/f64 element types. Marking these ops
+  // illegal unconditionally would make applyPartialConversion fail on, e.g.,
+  // complex<f16>, because no pattern can legalize them. So only the f32/f64
+  // instances are illegal.
+  target.addDynamicallyLegalOp<complex::AbsOp, complex::ExpOp>(
+      [](Operation *candidate) {
+        Type elementType = cast<ComplexType>(candidate->getOperand(0).getType())
+                               .getElementType();
+        return !isa<Float32Type, Float64Type>(elementType);
+      });
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
new file mode 100644
index 0000000000000..bae7c5986ef9e
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+
+// Each complex.abs is expected to become a call to the element-type-specific
+// __ocml_cabs_f32 / __ocml_cabs_f64 function, whose results are returned.
+//CHECK-LABEL: @abs_caller
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}})
+ %rf = complex.abs %f : complex<f32>
+ // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}})
+ %rd = complex.abs %d : complex<f64>
+ // CHECK: return %[[RF]], %[[RD]]
+ return %rf, %rd : f32, f64
+}
+
+// Each complex.exp is expected to become a call to the element-type-specific
+// __ocml_cexp_f32 / __ocml_cexp_f64 function, whose results are returned.
+//CHECK-LABEL: @exp_caller
+func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
+ %ef = complex.exp %f : complex<f32>
+ // CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}})
+ %ed = complex.exp %d : complex<f64>
+ // CHECK: return %[[EF]], %[[ED]]
+ return %ef, %ed : complex<f32>, complex<f64>
+}
More information about the Mlir-commits
mailing list