[Mlir-commits] [mlir] fc114e4 - [MLIR] Add ComplexToROCDLLibraryCalls pass (#144926)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 16 05:59:44 PDT 2025


Author: Akash Banerjee
Date: 2025-07-16T13:59:41+01:00
New Revision: fc114e4d931ae25f74a15e42371dbead1387ad51

URL: https://github.com/llvm/llvm-project/commit/fc114e4d931ae25f74a15e42371dbead1387ad51
DIFF: https://github.com/llvm/llvm-project/commit/fc114e4d931ae25f74a15e42371dbead1387ad51.diff

LOG: [MLIR] Add ComplexToROCDLLibraryCalls pass (#144926)

Added: 
    mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
    mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
    mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
    mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir

Modified: 
    flang/lib/Optimizer/CodeGen/CMakeLists.txt
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..16c7944a885a1 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -34,6 +34,7 @@ add_flang_library(FIRCodeGen
 
   MLIR_LIBS
   MLIRComplexToLLVM
+  MLIRComplexToROCDLLibraryCalls
   MLIRComplexToStandard
   MLIRGPUDialect
   MLIRMathToFuncs

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ecc04a6c9a2be..5ca53ee48955e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4145,22 +4146,24 @@ class FIRToLLVMLowering
     // conversions that affect the ModuleOp, e.g. create new
     // function operations in it. We have to run such conversions
     // as passes here.
-    mlir::OpPassManager mathConvertionPM("builtin.module");
+    mlir::OpPassManager mathConversionPM("builtin.module");
 
     bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
     // If compiling for AMD target some math operations must be lowered to AMD
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
-    if (isAMDGCN)
-      mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+    if (isAMDGCN) {
+      mathConversionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConversionPM.addPass(mlir::createConvertComplexToROCDLLibraryCalls());
+    }
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
     // it will be lowered to LLVM intrinsic operation by a later conversion.
     mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
     mathToFuncsOptions.minWidthOfFPowIExponent = 33;
-    mathConvertionPM.addPass(
+    mathConversionPM.addPass(
         mlir::createConvertMathToFuncs(mathToFuncsOptions));
 
     mlir::ConvertComplexToStandardPassOptions complexToStandardOptions{};
@@ -4173,15 +4176,15 @@ class FIRToLLVMLowering
       complexToStandardOptions.complexRange =
           mlir::complex::ComplexRangeFlags::improved;
     }
-    mathConvertionPM.addPass(
+    mathConversionPM.addPass(
         mlir::createConvertComplexToStandardPass(complexToStandardOptions));
 
     // Convert Math dialect operations into LLVM dialect operations.
     // There is no way to prefer MathToLLVM patterns over MathToLibm
     // patterns (applied below), so we have to run MathToLLVM conversion here.
-    mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
+    mathConversionPM.addNestedPass<mlir::func::FuncOp>(
         mlir::createConvertMathToLLVMPass());
-    if (mlir::failed(runPipeline(mathConvertionPM, mod)))
+    if (mlir::failed(runPipeline(mathConversionPM, mod)))
       return signalPassFailure();
 
     std::optional<mlir::DataLayout> dl =

diff  --git a/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h b/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
new file mode 100644
index 0000000000000..daac2a99ed80f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h
@@ -0,0 +1,27 @@
+//===- ComplexToROCDLLibraryCalls.h - convert from Complex to ROCDL calls -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate `patterns` with rewrites that replace f32/f64 complex.abs and
+/// complex.exp ops with calls to the matching `__ocml_*` functions.
+void populateComplexToROCDLLibraryCallsConversionPatterns(
+    RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8a5976e547169..d93fbefab74aa 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..76e751243a12c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
   let dependentDialects = ["func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToROCDLLibraryCalls
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDLLibraryCalls : Pass<"convert-complex-to-rocdl-library-calls", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL library calls";
+  let description = [{
+    This pass converts supported Complex ops to calls to the AMD device library.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToSPIRV
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 24a48993ad80c..f84375b6b8d6a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDLLibraryCalls)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)

diff  --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
new file mode 100644
index 0000000000000..695bb2dd0a82c
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDLLibraryCalls
+  ComplexToROCDLLibraryCalls.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDLLibraryCalls
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )

diff  --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
new file mode 100644
index 0000000000000..99d5424aef79a
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -0,0 +1,92 @@
+//=== ComplexToROCDLLibraryCalls.cpp - convert from Complex to ROCDL calls ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+template <typename Op, typename FloatTy>
+// Rewrites `Op` with result element type `FloatTy` to a call of `funcName`.
+struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ComplexOpToROCDLLibraryCalls(MLIRContext *context, StringRef funcName,
+                               PatternBenefit benefit = 1)
+      : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    Operation *symTable = SymbolTable::getNearestSymbolTable(op);
+    Type resType = op.getType();
+    if (auto complexType = dyn_cast<ComplexType>(resType))
+      resType = complexType.getElementType();
+    if (!isa<FloatTy>(resType))
+      return failure(); // element type is not FloatTy
+
+    auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+        SymbolTable::lookupSymbolIn(symTable, funcName));
+    if (!opFunc) { // no such symbol yet: declare it at the table start
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
+      auto funcTy = FunctionType::get(
+          rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+      opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
+                                             funcTy);
+      opFunc.setPrivate();
+    }
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  std::string funcName; // symbol name of the callee
+};
+} // namespace
+
+void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
+    RewritePatternSet &patterns) { // one pattern per (op, float type) pair
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
+      patterns.getContext(), "__ocml_cabs_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
+      patterns.getContext(), "__ocml_cabs_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
+      patterns.getContext(), "__ocml_cexp_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
+      patterns.getContext(), "__ocml_cexp_f64");
+}
+
+namespace {
+struct ConvertComplexToROCDLLibraryCallsPass
+    : public impl::ConvertComplexToROCDLLibraryCallsBase<
+          ConvertComplexToROCDLLibraryCallsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>(); // emitted calls stay legal
+  target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); // whatever the type
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}

diff  --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
new file mode 100644
index 0000000000000..bae7c5986ef9e
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+
+//CHECK-LABEL: @abs_caller
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}})
+  %rf = complex.abs %f : complex<f32>
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}})
+  %rd = complex.abs %d : complex<f64>
+  // CHECK: return %[[RF]], %[[RD]]
+  return %rf, %rd : f32, f64
+}
+
+//CHECK-LABEL: @exp_caller
+func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
+  %ef = complex.exp %f : complex<f32>
+  // CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}})
+  %ed = complex.exp %d : complex<f64>
+  // CHECK: return %[[EF]], %[[ED]]
+  return %ef, %ed : complex<f32>, complex<f64>
+}


        


More information about the Mlir-commits mailing list