[Mlir-commits] [mlir] [MLIR] Add conversion support for more ops in ComplexToROCDLLibraryCalls (PR #151166)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 29 08:01:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Akash Banerjee (TIFitis)

<details>
<summary>Changes</summary>

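This patch adds conversion support for the complex dialect ops AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass, lowering their f32 and f64 forms to calls to the corresponding OCML functions (`__ocml_*_f32` / `__ocml_*_f64`).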

---
Full diff: https://github.com/llvm/llvm-project/pull/151166.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp (+40-1) 
- (modified) mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir (+108) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc2965e6fd..35ad99c7791db 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
       patterns.getContext(), "__ocml_cabs_f32");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
       patterns.getContext(), "__ocml_cabs_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+      patterns.getContext(), "__ocml_carg_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+      patterns.getContext(), "__ocml_carg_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+      patterns.getContext(), "__ocml_conj_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+      patterns.getContext(), "__ocml_conj_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+      patterns.getContext(), "__ocml_ccos_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+      patterns.getContext(), "__ocml_ccos_f64");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
       patterns.getContext(), "__ocml_cexp_f32");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
       patterns.getContext(), "__ocml_cexp_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+      patterns.getContext(), "__ocml_clog_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+      patterns.getContext(), "__ocml_clog_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+      patterns.getContext(), "__ocml_cpow_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+      patterns.getContext(), "__ocml_cpow_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+      patterns.getContext(), "__ocml_csin_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+      patterns.getContext(), "__ocml_csin_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+      patterns.getContext(), "__ocml_csqrt_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+      patterns.getContext(), "__ocml_csqrt_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+      patterns.getContext(), "__ocml_ctan_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+      patterns.getContext(), "__ocml_ctan_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+      patterns.getContext(), "__ocml_ctanh_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+      patterns.getContext(), "__ocml_ctanh_f64");
 }
 
 namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
-  target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+  target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+                      complex::CosOp, complex::ExpOp, complex::LogOp,
+                      complex::PowOp, complex::SinOp, complex::SqrtOp,
+                      complex::TanOp, complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index bae7c5986ef9e..ae59f28b46392 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -2,8 +2,26 @@
 
 // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
 // CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
 // CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
 
 //CHECK-LABEL: @abs_caller
 func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
   return %rf, %rd : f32, f64
 }
 
+//CHECK-LABEL: @angle_caller
+func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[ANGLE_F:.*]] = call @__ocml_carg_f32(%{{.*}})
+  %angle_f = complex.angle %f : complex<f32>
+  // CHECK: %[[ANGLE_D:.*]] = call @__ocml_carg_f64(%{{.*}})
+  %angle_d = complex.angle %d : complex<f64>
+  // CHECK: return %[[ANGLE_F]], %[[ANGLE_D]]
+  return %angle_f, %angle_d : f32, f64
+}
+
+//CHECK-LABEL: @cos_caller
+func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[COS_F:.*]] = call @__ocml_ccos_f32(%{{.*}})
+  %cos_f = complex.cos %f : complex<f32>
+  // CHECK: %[[COS_D:.*]] = call @__ocml_ccos_f64(%{{.*}})
+  %cos_d = complex.cos %d : complex<f64>
+  // CHECK: return %[[COS_F]], %[[COS_D]]
+  return %cos_f, %cos_d : complex<f32>, complex<f64>
+}
+
 //CHECK-LABEL: @exp_caller
 func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
   // CHECK: return %[[EF]], %[[ED]]
   return %ef, %ed : complex<f32>, complex<f64>
 }
+
+//CHECK-LABEL: @log_caller
+func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[LOG_F:.*]] = call @__ocml_clog_f32(%{{.*}})
+  %log_f = complex.log %f : complex<f32>
+  // CHECK: %[[LOG_D:.*]] = call @__ocml_clog_f64(%{{.*}})
+  %log_d = complex.log %d : complex<f64>
+  // CHECK: return %[[LOG_F]], %[[LOG_D]]
+  return %log_f, %log_d : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @conj_caller
+func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[CONJ_F:.*]] = call @__ocml_conj_f32(%{{.*}})
+  %conj_f = complex.conj %f : complex<f32>
+  // CHECK: %[[CONJ_D:.*]] = call @__ocml_conj_f64(%{{.*}})
+  %conj_d = complex.conj %d : complex<f64>
+  // CHECK: return %[[CONJ_F]], %[[CONJ_D]]
+  return %conj_f, %conj_d : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[POW_F:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
+  %pow_f = complex.pow %f, %f : complex<f32>
+  // CHECK: %[[POW_D:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
+  %pow_d = complex.pow %d, %d : complex<f64>
+  // CHECK: return %[[POW_F]], %[[POW_D]]
+  return %pow_f, %pow_d : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sin_caller
+func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[SIN_F:.*]] = call @__ocml_csin_f32(%{{.*}})
+  %sin_f = complex.sin %f : complex<f32>
+  // CHECK: %[[SIN_D:.*]] = call @__ocml_csin_f64(%{{.*}})
+  %sin_d = complex.sin %d : complex<f64>
+  // CHECK: return %[[SIN_F]], %[[SIN_D]]
+  return %sin_f, %sin_d : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sqrt_caller
+func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[SQRT_F:.*]] = call @__ocml_csqrt_f32(%{{.*}})
+  %sqrt_f = complex.sqrt %f : complex<f32>
+  // CHECK: %[[SQRT_D:.*]] = call @__ocml_csqrt_f64(%{{.*}})
+  %sqrt_d = complex.sqrt %d : complex<f64>
+  // CHECK: return %[[SQRT_F]], %[[SQRT_D]]
+  return %sqrt_f, %sqrt_d : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tan_caller
+func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[TAN_F:.*]] = call @__ocml_ctan_f32(%{{.*}})
+  %tan_f = complex.tan %f : complex<f32>
+  // CHECK: %[[TAN_D:.*]] = call @__ocml_ctan_f64(%{{.*}})
+  %tan_d = complex.tan %d : complex<f64>
+  // CHECK: return %[[TAN_F]], %[[TAN_D]]
+  return %tan_f, %tan_d : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tanh_caller
+func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[TANH_F:.*]] = call @__ocml_ctanh_f32(%{{.*}})
+  %tanh_f = complex.tanh %f : complex<f32>
+  // CHECK: %[[TANH_D:.*]] = call @__ocml_ctanh_f64(%{{.*}})
+  %tanh_d = complex.tanh %d : complex<f64>
+  // CHECK: return %[[TANH_F]], %[[TANH_D]]
+  return %tanh_f, %tanh_d : complex<f32>, complex<f64>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/151166


More information about the Mlir-commits mailing list