[Mlir-commits] [mlir] [mlir][complex] Lower complex.conj to LLVM (PR #98486)

Guillermo Callaghan llvmlistbot at llvm.org
Thu Jul 11 07:04:37 PDT 2024


https://github.com/Guillermo-Callaghan created https://github.com/llvm/llvm-project/pull/98486

None

>From 19339e2e83f2eaafdc1bc43d1fe23730cf7c0a61 Mon Sep 17 00:00:00 2001
From: Guillermo Callaghan <guillermo.callaghan at huawei.com>
Date: Thu, 11 Jul 2024 21:59:18 +0800
Subject: [PATCH] [mlir][complex] Lower complex.conj to LLVM

---
 .../ComplexToLLVM/ComplexToLLVM.cpp           | 27 +++++++++++++
 .../ComplexToLLVM/convert-to-llvm.mlir        | 39 +++++++++++++++++++
 .../ComplexToLLVM/full-conversion.mlir        | 22 +++++++++++
 3 files changed, 88 insertions(+)

diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 0a3c3a330ff69..db5ef0089e729 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -63,6 +63,32 @@ Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
 
 namespace {
 
+/// Lowers `complex.conj`: since conj(a + bi) = a - bi, the real field of the
+/// lowered struct is kept and only the imaginary field is negated (llvm.fneg).
+struct ConjOpLowering : public ConvertOpToLLVMPattern<complex::ConjOp> {
+  using ConvertOpToLLVMPattern<complex::ConjOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    ComplexStructBuilder complexStruct(adaptor.getComplex());
+    Value imag = complexStruct.imaginary(rewriter, loc);
+
+    // Carry the op's arith fastmath flags over to the emitted llvm.fneg.
+    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
+    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
+        op.getContext(),
+        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
+    Value negImag = rewriter.create<LLVM::FNegOp>(loc, imag, fmf);
+
+    // setImaginary updates the builder's struct value, which replaces `op`.
+    complexStruct.setImaginary(rewriter, loc, negImag);
+    rewriter.replaceOp(op, {complexStruct});
+    return success();
+  }
+};
+
 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
 
@@ -328,6 +354,7 @@ void mlir::populateComplexToLLVMConversionPatterns(
   patterns.add<
       AbsOpConversion,
       AddOpConversion,
+      ConjOpLowering,
       ConstantOpLowering,
       CreateOpConversion,
       DivOpConversion,
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index a60b974e374d3..76a96cea7d89c 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -1,5 +1,44 @@
 // RUN: mlir-opt %s -convert-complex-to-llvm | FileCheck %s
 
+// CHECK-LABEL: func @complex_conj_f32
+// CHECK-SAME:    (%[[CPLX:.*]]: complex<f32>)
+// CHECK-NEXT:    %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG_NEG:.*]] = llvm.fneg %[[IMAG]] : f32
+// CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG_NEG]], %[[CAST0]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[CPLX2]] : !llvm.struct<(f32, f32)> to complex<f32>
+// CHECK-NEXT:    return %[[CAST2]] : complex<f32>
+func.func @complex_conj_f32(%cplx: complex<f32>) -> complex<f32> {
+  %conj = complex.conj %cplx : complex<f32>
+  return %conj : complex<f32>
+}
+
+// CHECK-LABEL: func @complex_conj_f64
+// CHECK-SAME:    (%[[CPLX:.*]]: complex<f64>)
+// CHECK-NEXT:    %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f64> to !llvm.struct<(f64, f64)>
+// CHECK-NEXT:    %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f64, f64)>
+// CHECK-NEXT:    %[[IMAG_NEG:.*]] = llvm.fneg %[[IMAG]] : f64
+// CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG_NEG]], %[[CAST0]][1] : !llvm.struct<(f64, f64)>
+// CHECK-NEXT:    %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[CPLX2]] : !llvm.struct<(f64, f64)> to complex<f64>
+// CHECK-NEXT:    return %[[CAST2]] : complex<f64>
+func.func @complex_conj_f64(%cplx: complex<f64>) -> complex<f64> {
+  %conj = complex.conj %cplx : complex<f64>
+  return %conj : complex<f64>
+}
+
+// CHECK-LABEL: func @complex_conj_with_fmf
+// CHECK-SAME:    (%[[CPLX:.*]]: complex<f32>)
+// CHECK-NEXT:    %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG_NEG:.*]] = llvm.fneg %[[IMAG]] {fastmathFlags = #llvm.fastmath<contract>} : f32
+// CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG_NEG]], %[[CAST0]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[CPLX2]] : !llvm.struct<(f32, f32)> to complex<f32>
+// CHECK-NEXT:    return %[[CAST2]] : complex<f32>
+func.func @complex_conj_with_fmf(%cplx: complex<f32>) -> complex<f32> {
+  %conj = complex.conj %cplx fastmath<contract> : complex<f32>
+  return %conj : complex<f32>
+}
+
 // Same below, but using the `ConvertToLLVMPatternInterface` entry point
 // and the generic `convert-to-llvm` pass.
 // RUN: mlir-opt --convert-to-llvm="filter-dialects=complex" --split-input-file %s | FileCheck %s
diff --git a/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
index b7756b3be543f..bf6bc4665c215 100644
--- a/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
@@ -1,5 +1,27 @@
 // RUN: mlir-opt %s -convert-complex-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | FileCheck %s
 
+// CHECK-LABEL: llvm.func @complex_conj_f32
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*>]]) -> ![[C_TY]]
+// CHECK:      %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK-NEXT: %[[IMAG_NEG:.*]] = llvm.fneg %[[IMAG]] : f32
+// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG_NEG]], %[[ARG]][1] : ![[C_TY]]
+// CHECK-NEXT: llvm.return %[[CPLX2]] : ![[C_TY]]
+func.func @complex_conj_f32(%cplx: complex<f32>) -> complex<f32> {
+  %conj = complex.conj %cplx : complex<f32>
+  return %conj : complex<f32>
+}
+
+// CHECK-LABEL: llvm.func @complex_conj_f64
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*>]]) -> ![[C_TY]]
+// CHECK:      %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK-NEXT: %[[IMAG_NEG:.*]] = llvm.fneg %[[IMAG]] : f64
+// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG_NEG]], %[[ARG]][1] : ![[C_TY]]
+// CHECK-NEXT: llvm.return %[[CPLX2]] : ![[C_TY]]
+func.func @complex_conj_f64(%cplx: complex<f64>) -> complex<f64> {
+  %conj = complex.conj %cplx : complex<f64>
+  return %conj : complex<f64>
+}
+
 // CHECK-LABEL: llvm.func @complex_div
 // CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
 func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {



More information about the Mlir-commits mailing list