[flang-commits] [flang] 14de5a2 - [mlir][complex] Initial support for FastMath flag when converting to LLVM

Kai Sasaki via flang-commits flang-commits at lists.llvm.org
Sun Aug 20 18:44:08 PDT 2023


Author: Kai Sasaki
Date: 2023-08-21T10:41:55+09:00
New Revision: 14de5a2a4f27ed719b1f59f99d2730624a26f706

URL: https://github.com/llvm/llvm-project/commit/14de5a2a4f27ed719b1f59f99d2730624a26f706
DIFF: https://github.com/llvm/llvm-project/commit/14de5a2a4f27ed719b1f59f99d2730624a26f706.diff

LOG: [mlir][complex] Initial support for FastMath flag when converting to LLVM

This change adds initial support for FastMath flags in the Complex dialect. Similar to what was done for the [Arith dialect](https://reviews.llvm.org/rGb56e65d31825fe4a1ae02fdcbad58bb7993d63a7), `fastmath` attributes in the Complex dialect are mapped directly to the corresponding LLVM fastmath flags.

In this diff,

- Use of FastMathAttr (the Arith dialect's `Arith_FastMathAttr`, an EnumAttr-based attribute) as the fastmath attribute for operations in the Complex dialect.
- Use of `ArithFastMathInterface` as the interface implemented by Complex dialect operations that carry a fastmath attribute.
- Declaration of a default-valued fastmath attribute for unary and arithmetic operations in the Complex dialect.
- Conversion code to lower arithmetic fastmath flags to LLVM fastmath flags

NOT in this diff (but planned and progressively implemented):

- Documentation of flag meanings
- Support for fastmath flag conversion to the Arith dialect
- Folding/rewrite implementations that are enabled by fastmath flags (although these are the original motivation for supporting the flags)

RFC: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D156310

Added: 
    

Modified: 
    flang/test/Lower/Intrinsics/abs.f90
    flang/test/Lower/Intrinsics/exp.f90
    flang/test/Lower/Intrinsics/log.f90
    flang/test/Lower/complex-operations.f90
    mlir/include/mlir/Dialect/Complex/IR/Complex.h
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
    mlir/lib/Dialect/Complex/IR/CMakeLists.txt
    mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Lower/Intrinsics/abs.f90 b/flang/test/Lower/Intrinsics/abs.f90
index d2288a140ad43f..7986eeee000306 100644
--- a/flang/test/Lower/Intrinsics/abs.f90
+++ b/flang/test/Lower/Intrinsics/abs.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,CMPLX,CMPLX-PRECISE"
 ! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE"
 ! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST"
-! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST"
+! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-APPROX"
 
 ! Test abs intrinsic for various types (int, float, complex)
 
@@ -100,7 +100,9 @@ subroutine abs_testr16(a, b)
 subroutine abs_testzr(a, b)
 ! CMPLX:  %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.complex<4>>
 ! CMPLX-FAST: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<4>) -> complex<f32>
-! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] : complex<f32>
+! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath<contract> : complex<f32>
+! CMPLX-APPROX: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<4>) -> complex<f32>
+! CMPLX-APPROX: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath<contract,afn> : complex<f32>
 ! CMPLX-PRECISE:  %[[VAL_4:.*]] = fir.call @cabsf(%[[VAL_2]]) {{.*}}: (!fir.complex<4>) -> f32
 ! CMPLX:  fir.store %[[VAL_4]] to %[[VAL_1]] : !fir.ref<f32>
 ! CMPLX:  return
@@ -114,7 +116,9 @@ end subroutine abs_testzr
 subroutine abs_testzd(a, b)
 ! CMPLX:  %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.complex<8>>
 ! CMPLX-FAST: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<8>) -> complex<f64>
-! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] : complex<f64>
+! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath<contract> : complex<f64>
+! CMPLX-APPROX: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<8>) -> complex<f64>
+! CMPLX-APPROX: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath<contract,afn> : complex<f64>
 ! CMPLX-PRECISE:  %[[VAL_4:.*]] = fir.call @cabs(%[[VAL_2]]) {{.*}}: (!fir.complex<8>) -> f64
 ! CMPLX:  fir.store %[[VAL_4]] to %[[VAL_1]] : !fir.ref<f64>
 ! CMPLX:  return

diff  --git a/flang/test/Lower/Intrinsics/exp.f90 b/flang/test/Lower/Intrinsics/exp.f90
index f128e548edeb0c..49ed25a5b8e6cb 100644
--- a/flang/test/Lower/Intrinsics/exp.f90
+++ b/flang/test/Lower/Intrinsics/exp.f90
@@ -2,7 +2,7 @@
 ! RUN: bbc -emit-fir --math-runtime=precise -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE"
 ! RUN: bbc -emit-fir --force-mlir-complex -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-MLIR"
 ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CHECK,CMPLX,CMPLX-PRECISE"
-! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-APPROX"
+! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-APPROX"
 ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE"
 ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-MLIR"
 
@@ -61,8 +61,11 @@ subroutine exp_testcd(a, b)
 ! CMPLX-MLIR-LABEL: private @fir.exp.contract.z4.z4
 ! CMPLX-SAME: (%[[ARG32_OUTLINE:.*]]: !fir.complex<4>) -> !fir.complex<4>
 ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex<f32>
-! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] : complex<f32>
+! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] fastmath<contract> : complex<f32>
 ! CMPLX-FAST: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f32>) -> !fir.complex<4>
+! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex<f32>
+! CMPLX-APPROX: %[[E:.*]] = complex.exp %[[C]] fastmath<contract,afn> : complex<f32>
+! CMPLX-APPROX: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f32>) -> !fir.complex<4>
 ! CMPLX-PRECISE: %[[RESULT32_OUTLINE:.*]] = fir.call @cexpf(%[[ARG32_OUTLINE]]) fastmath<contract> : (!fir.complex<4>) -> !fir.complex<4>
 ! CMPLX: return %[[RESULT32_OUTLINE]] : !fir.complex<4>
 
@@ -71,7 +74,10 @@ subroutine exp_testcd(a, b)
 ! CMPLX-MLIR-LABEL: private @fir.exp.contract.z8.z8
 ! CMPLX-SAME: (%[[ARG64_OUTLINE:.*]]: !fir.complex<8>) -> !fir.complex<8>
 ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex<f64>
-! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] : complex<f64>
+! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] fastmath<contract> : complex<f64>
 ! CMPLX-FAST: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f64>) -> !fir.complex<8>
+! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex<f64>
+! CMPLX-APPROX: %[[E:.*]] = complex.exp %[[C]] fastmath<contract,afn> : complex<f64>
+! CMPLX-APPROX: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f64>) -> !fir.complex<8>
 ! CMPLX-PRECISE: %[[RESULT64_OUTLINE:.*]] = fir.call @cexp(%[[ARG64_OUTLINE]]) fastmath<contract> : (!fir.complex<8>) -> !fir.complex<8>
 ! CMPLX: return %[[RESULT64_OUTLINE]] : !fir.complex<8>

diff  --git a/flang/test/Lower/Intrinsics/log.f90 b/flang/test/Lower/Intrinsics/log.f90
index 49be4d968c8908..08dbd4218d64fb 100644
--- a/flang/test/Lower/Intrinsics/log.f90
+++ b/flang/test/Lower/Intrinsics/log.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CHECK,CMPLX,CMPLX-PRECISE"
 ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE"
 ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-MLIR"
-! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-APPROX"
+! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-APPROX"
 
 ! CHECK-LABEL: log_testr
 ! CHECK-SAME: (%[[AREF:.*]]: !fir.ref<f32> {{.*}}, %[[BREF:.*]]: !fir.ref<f32> {{.*}})
@@ -81,8 +81,11 @@ subroutine log10_testd(a, b)
 ! CMPLX-MLIR-LABEL: private @fir.log.contract.z4.z4
 ! CMPLX-SAME: (%[[ARG32_OUTLINE:.*]]: !fir.complex<4>) -> !fir.complex<4>
 ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex<f32>
-! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] : complex<f32>
+! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] fastmath<contract> : complex<f32>
 ! CMPLX-FAST: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f32>) -> !fir.complex<4>
+! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex<f32>
+! CMPLX-APPROX: %[[E:.*]] = complex.log %[[C]] fastmath<contract,afn> : complex<f32>
+! CMPLX-APPROX: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f32>) -> !fir.complex<4>
 ! CMPLX-PRECISE: %[[RESULT32_OUTLINE:.*]] = fir.call @clogf(%[[ARG32_OUTLINE]]) fastmath<contract> : (!fir.complex<4>) -> !fir.complex<4>
 ! CMPLX: return %[[RESULT32_OUTLINE]] : !fir.complex<4>
 
@@ -91,8 +94,11 @@ subroutine log10_testd(a, b)
 ! CMPLX-MLIR-LABEL: private @fir.log.contract.z8.z8
 ! CMPLX-SAME: (%[[ARG64_OUTLINE:.*]]: !fir.complex<8>) -> !fir.complex<8>
 ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex<f64>
-! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] : complex<f64>
+! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] fastmath<contract> : complex<f64>
 ! CMPLX-FAST: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f64>) -> !fir.complex<8>
+! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex<f64>
+! CMPLX-APPROX: %[[E:.*]] = complex.log %[[C]] fastmath<contract,afn> : complex<f64>
+! CMPLX-APPROX: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex<f64>) -> !fir.complex<8>
 ! CMPLX-PRECISE: %[[RESULT64_OUTLINE:.*]] = fir.call @clog(%[[ARG64_OUTLINE]]) fastmath<contract> : (!fir.complex<8>) -> !fir.complex<8>
 ! CMPLX: return %[[RESULT64_OUTLINE]] : !fir.complex<8>
 

diff  --git a/flang/test/Lower/complex-operations.f90 b/flang/test/Lower/complex-operations.f90
index c686671c7a1123..42cdac0dc2a21b 100644
--- a/flang/test/Lower/complex-operations.f90
+++ b/flang/test/Lower/complex-operations.f90
@@ -33,7 +33,7 @@ end subroutine mul_test
 ! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<2>>
 ! CHECK: %[[BVAL_CVT:.*]] = fir.convert %[[BVAL]] : (!fir.complex<2>) -> complex<f16>
 ! CHECK: %[[CVAL_CVT:.*]] = fir.convert %[[CVAL]] : (!fir.complex<2>) -> complex<f16>
-! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] : complex<f16>
+! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] fastmath<contract> : complex<f16>
 ! CHECK: %[[AVAL:.*]] = fir.convert %[[AVAL_CVT]] : (complex<f16>) -> !fir.complex<2>
 ! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<2>>
 subroutine div_test_half(a,b,c)
@@ -47,7 +47,7 @@ end subroutine div_test_half
 ! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref<!fir.complex<3>>
 ! CHECK: %[[BVAL_CVT:.*]] = fir.convert %[[BVAL]] : (!fir.complex<3>) -> complex<bf16>
 ! CHECK: %[[CVAL_CVT:.*]] = fir.convert %[[CVAL]] : (!fir.complex<3>) -> complex<bf16>
-! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] : complex<bf16>
+! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] fastmath<contract> : complex<bf16>
 ! CHECK: %[[AVAL:.*]] = fir.convert %[[AVAL_CVT]] : (complex<bf16>) -> !fir.complex<3>
 ! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref<!fir.complex<3>>
 subroutine div_test_bfloat(a,b,c)

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
index 663e81c71d8605..fb024fa2e951ea 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h
+++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index b80d77996a20f5..a829fa88efa893 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -9,6 +9,8 @@
 #ifndef COMPLEX_OPS
 #define COMPLEX_OPS
 
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Dialect/Complex/IR/ComplexBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -22,19 +24,21 @@ class Complex_Op<string mnemonic, list<Trait> traits = []>
 // one result, all of which must be complex numbers of the same type.
 class ComplexArithmeticOp<string mnemonic, list<Trait> traits = []> :
     Complex_Op<mnemonic, traits # [Pure, SameOperandsAndResultType,
-    Elementwise]> {
-  let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
+    Elementwise, DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs, DefaultValuedAttr<
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
   let results = (outs Complex<AnyFloat>:$result);
-  let assemblyFormat = "$lhs `,` $rhs  attr-dict `:` type($result)";
+  let assemblyFormat = "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result)";
 }
 
 // Base class for standard unary operations on complex numbers with a
 // floating-point element type. These operations take one operand and return
 // one result; the operand must be a complex number.
 class ComplexUnaryOp<string mnemonic, list<Trait> traits = []> :
-    Complex_Op<mnemonic, traits # [Pure, Elementwise]> {
-  let arguments = (ins Complex<AnyFloat>:$complex);
-  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+    Complex_Op<mnemonic, traits # [Pure, Elementwise, DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let arguments = (ins Complex<AnyFloat>:$complex, DefaultValuedAttr<
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let assemblyFormat = "$complex (`fastmath` `` $fastmath^)? attr-dict `:` type($complex)";
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index b219de7fef8f8c..68406125ba5268 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -23,6 +24,7 @@ namespace mlir {
 
 using namespace mlir;
 using namespace mlir::LLVM;
+using namespace mlir::arith;
 
 //===----------------------------------------------------------------------===//
 // ComplexStructBuilder implementation.
@@ -73,7 +75,10 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
     Value real = complexStruct.real(rewriter, op.getLoc());
     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
 
-    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
+    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
+    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
+        op.getContext(),
+        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
     Value sqNorm = rewriter.create<LLVM::FAddOp>(
         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
@@ -181,7 +186,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
-    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
+    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
+    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
+        op.getContext(),
+        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
     Value real =
         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
     Value imag =
@@ -209,7 +217,10 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
-    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
+    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
+    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
+        op.getContext(),
+        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
     Value rhsRe = arg.rhs.real();
     Value rhsIm = arg.rhs.imag();
     Value lhsRe = arg.lhs.real();
@@ -254,7 +265,10 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
-    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
+    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
+    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
+        op.getContext(),
+        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
     Value rhsRe = arg.rhs.real();
     Value rhsIm = arg.rhs.imag();
     Value lhsRe = arg.lhs.real();
@@ -291,7 +305,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to substract complex numbers.
-    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
+    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
+    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
+        op.getContext(),
+        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
     Value real =
         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
     Value imag =

diff  --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
index a90f34ec1684d9..3ee0d26f3225f7 100644
--- a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRComplexDialect
   MLIRComplexAttributesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArithAttrToLLVMConversion
   MLIRArithDialect
   MLIRDialect
   MLIRInferTypeOpInterface

diff  --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index 3b8ed25d6073cc..a60b974e374d37 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -154,3 +154,125 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
 // CHECK: return %[[NORM]] : f32
 
+// CHECK-LABEL: func @complex_addition_with_fmf
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] {fastmathFlags = #llvm.fastmath<reassoc>} : f64
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] {fastmathFlags = #llvm.fastmath<reassoc>} : f64
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
+func.func @complex_addition_with_fmf() {
+  %a_re = arith.constant 1.2 : f64
+  %a_im = arith.constant 3.4 : f64
+  %a = complex.create %a_re, %a_im : complex<f64>
+  %b_re = arith.constant 5.6 : f64
+  %b_im = arith.constant 7.8 : f64
+  %b = complex.create %b_re, %b_im : complex<f64>
+  %c = complex.add %a, %b fastmath<reassoc> : complex<f64>
+  return
+}
+
+// CHECK-LABEL: func @complex_substraction_with_fmf
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f64
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f64
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
+func.func @complex_substraction_with_fmf() {
+  %a_re = arith.constant 1.2 : f64
+  %a_im = arith.constant 3.4 : f64
+  %a = complex.create %a_re, %a_im : complex<f64>
+  %b_re = arith.constant 5.6 : f64
+  %b_im = arith.constant 7.8 : f64
+  %b = complex.create %b_re, %b_im : complex<f64>
+  %c = complex.sub %a, %b fastmath<nnan,ninf> : complex<f64>
+  return
+}
+
+// CHECK-LABEL: func @complex_div_with_fmf
+// CHECK-SAME:    %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex<f32> to ![[C_TY]]
+
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
+
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+
+// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath<nsz, arcp>} : f32
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
+//
+// CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex<f32>
+// CHECK: return %[[CASTED_RESULT]] : complex<f32>
+func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %div = complex.div %lhs, %rhs fastmath<nsz,arcp> : complex<f32>
+  return %div : complex<f32>
+}
+
+
+// CHECK-LABEL: func @complex_mul_with_fmf
+// CHECK-SAME:    %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex<f32> to ![[C_TY]]
+
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
+
+// CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex<f32>
+// CHECK: return %[[CASTED_RESULT]] : complex<f32>
+func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs fastmath<contract,afn>  : complex<f32>
+  return %mul : complex<f32>
+}
+
+// CHECK-LABEL: func @complex_abs_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+// CHECK: %[[CASTED_ARG:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[CASTED_ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[CASTED_ARG]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] {fastmathFlags = #llvm.fastmath<contract>} : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] {fastmathFlags = #llvm.fastmath<contract>} : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] {fastmathFlags = #llvm.fastmath<contract>} : f32
+// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: return %[[NORM]] : f32
+func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
+  %abs = complex.abs %arg fastmath<contract> : complex<f32>
+  return %abs : f32
+}


        


More information about the flang-commits mailing list