[flang-commits] [flang] [flang] Add fastmath attributes to complex arithmetic (PR #70690)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Tue Oct 31 04:16:55 PDT 2023


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/70690

>From 08fec5ce5f8697cfdc14b333a137a054802f4c57 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 30 Oct 2023 15:42:17 +0000
Subject: [PATCH 1/3] [flang] add fastmath attributes to FIR complex operations

These attributes (once propagated to LLVM) allow multiple operations to
be merged into one, e.g. into a fused multiply-add.

I will add support for these attributes in CodeGen in my next patch.
---
 .../include/flang/Optimizer/Dialect/FIROps.td  | 18 ++++++++++++------
 flang/test/Lower/HLFIR/binary-ops.f90          |  6 +++---
 flang/test/Lower/OpenACC/acc-reduction.f90     |  4 ++--
 flang/test/Lower/array-elemental-calls-2.f90   |  2 +-
 flang/test/Lower/assignment.f90                |  6 +++---
 5 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dd2e90c3b1a1fde..6e8064a63b7ae0a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2538,12 +2538,18 @@ def fir_NegcOp : ComplexUnaryArithmeticOp<"negc">;
 
 class ComplexArithmeticOp<string mnemonic, list<Trait> traits = []> :
       fir_ArithmeticOp<mnemonic, traits>,
-      Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs)>;
-
-def fir_AddcOp : ComplexArithmeticOp<"addc", [Commutative]>;
-def fir_SubcOp : ComplexArithmeticOp<"subc">;
-def fir_MulcOp : ComplexArithmeticOp<"mulc", [Commutative]>;
-def fir_DivcOp : ComplexArithmeticOp<"divc">;
+      Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs,
+          DefaultValuedAttr<Arith_FastMathAttr,
+                            "::mlir::arith::FastMathFlags::none">:$fastmath)>;
+
+def fir_AddcOp : ComplexArithmeticOp<"addc",
+    [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_SubcOp : ComplexArithmeticOp<"subc",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_MulcOp : ComplexArithmeticOp<"mulc",
+    [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_DivcOp : ComplexArithmeticOp<"divc",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
 // Pow is a builtin call and not a primitive
 
 def fir_CmpcOp : fir_Op<"cmpc",
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 8db6da3de81b291..6b89577cc54581b 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -32,7 +32,7 @@ subroutine complex_add(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK:  %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK:  %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 
 subroutine int_sub(x, y, z)
  integer :: x, y, z
@@ -65,7 +65,7 @@ subroutine complex_sub(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK:  %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK:  %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 
 subroutine int_mul(x, y, z)
  integer :: x, y, z
@@ -98,7 +98,7 @@ subroutine complex_mul(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK:  %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK:  %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 
 subroutine int_div(x, y, z)
  integer :: x, y, z
diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
index b874d5219625df8..8671c280c2fb314 100644
--- a/flang/test/Lower/OpenACC/acc-reduction.f90
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -163,7 +163,7 @@
 ! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
 ! CHECK:   %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK:   %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK:   %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:   fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK: }
@@ -183,7 +183,7 @@
 ! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
 ! CHECK:   %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK:   %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK:   %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:   fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK: }
diff --git a/flang/test/Lower/array-elemental-calls-2.f90 b/flang/test/Lower/array-elemental-calls-2.f90
index 94e24a9910bc267..0d6e34c6391c3df 100644
--- a/flang/test/Lower/array-elemental-calls-2.f90
+++ b/flang/test/Lower/array-elemental-calls-2.f90
@@ -144,7 +144,7 @@ subroutine check_cmplx_part()
 ! CHECK:  %[[VAL_13:.*]] = fir.load %{{.*}} : !fir.ref<!fir.complex<8>>
 ! CHECK:  fir.do_loop
 ! CHECK:  %[[VAL_23:.*]] = fir.array_fetch %{{.*}}, %{{.*}} : (!fir.array<10x!fir.complex<8>>, index) -> !fir.complex<8>
-! CHECK:  %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] : !fir.complex<8>
+! CHECK:  %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] {fastmath = #arith.fastmath<contract>} : !fir.complex<8>
 ! CHECK:  %[[VAL_25:.*]] = fir.extract_value %[[VAL_24]], [1 : index] : (!fir.complex<8>) -> f64
 ! CHECK:  fir.call @_QPelem_func_real(%[[VAL_25]]) {{.*}}: (f64) -> i32
 end subroutine
diff --git a/flang/test/Lower/assignment.f90 b/flang/test/Lower/assignment.f90
index 9b5039e3ea88ebd..058842828d2687a 100644
--- a/flang/test/Lower/assignment.f90
+++ b/flang/test/Lower/assignment.f90
@@ -203,7 +203,7 @@ real function divf(a, b)
 ! CHECK:         %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
 ! CHECK:         %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK:         %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK:         %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:         fir.store %[[ADD]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         return %[[RET]] : !fir.complex<4>
@@ -219,7 +219,7 @@ real function divf(a, b)
 ! CHECK:         %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
 ! CHECK:         %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK:         %[[SUB:.*]] = fir.subc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK:         %[[SUB:.*]] = fir.subc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:         fir.store %[[SUB]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         return %[[RET]] : !fir.complex<4>
@@ -235,7 +235,7 @@ real function divf(a, b)
 ! CHECK:         %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
 ! CHECK:         %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK:         %[[MUL:.*]] = fir.mulc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK:         %[[MUL:.*]] = fir.mulc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:         fir.store %[[MUL]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         return %[[RET]] : !fir.complex<4>

>From cd4203a78f1506086dc3b7027dafc15c8b90649c Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 30 Oct 2023 16:05:39 +0000
Subject: [PATCH 2/3] [flang] propagate fir complex fast math flags through
 lowering

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 26 +++++++++++++
 flang/test/Fir/convert-to-llvm.fir      | 50 ++++++++++++-------------
 2 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 0f85f89f1a48138..3f6f2b0474d44b4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3502,6 +3503,8 @@ static mlir::LLVM::InsertValueOp
 complexSum(OPTY sumop, mlir::ValueRange opnds,
            mlir::ConversionPatternRewriter &rewriter,
            const fir::LLVMTypeConverter &lowering) {
+  mlir::LLVM::FastmathFlags fastmathFlags =
+      mlir::arith::convertArithFastMathFlagsToLLVM(sumop.getFastmath());
   mlir::Value a = opnds[0];
   mlir::Value b = opnds[1];
   auto loc = sumop.getLoc();
@@ -3512,7 +3515,9 @@ complexSum(OPTY sumop, mlir::ValueRange opnds,
   auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
   auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
   auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
+  rx.setFastmathFlags(fastmathFlags);
   auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
+  ry.setFastmathFlags(fastmathFlags);
   auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
   auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
   return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
@@ -3560,6 +3565,8 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
     // TODO: Can we use a call to __muldc3 ?
     // given: (x + iy) * (x' + iy')
     // result: (xx'-yy')+i(xy'+yx')
+    mlir::LLVM::FastmathFlags fastmathFlags =
+        mlir::arith::convertArithFastMathFlagsToLLVM(mulc.getFastmath());
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = mulc.getLoc();
@@ -3570,11 +3577,17 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+    xx.setFastmathFlags(fastmathFlags);
     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+    yx.setFastmathFlags(fastmathFlags);
     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+    xy.setFastmathFlags(fastmathFlags);
     auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
+    ri.setFastmathFlags(fastmathFlags);
     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+    yy.setFastmathFlags(fastmathFlags);
     auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
+    rr.setFastmathFlags(fastmathFlags);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
@@ -3594,6 +3607,8 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
     // Just generate inline code for now.
     // given: (x + iy) / (x' + iy')
     // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
+    mlir::LLVM::FastmathFlags fastmathFlags =
+        mlir::arith::convertArithFastMathFlagsToLLVM(divc.getFastmath());
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
@@ -3604,16 +3619,27 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+    xx.setFastmathFlags(fastmathFlags);
     auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+    x1x1.setFastmathFlags(fastmathFlags);
     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+    yx.setFastmathFlags(fastmathFlags);
     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+    xy.setFastmathFlags(fastmathFlags);
     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+    yy.setFastmathFlags(fastmathFlags);
     auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+    y1y1.setFastmathFlags(fastmathFlags);
     auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+    d.setFastmathFlags(fastmathFlags);
     auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+    rrn.setFastmathFlags(fastmathFlags);
     auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+    rin.setFastmathFlags(fastmathFlags);
     auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+    rr.setFastmathFlags(fastmathFlags);
     auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+    ri.setFastmathFlags(fastmathFlags);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index cecfbff7eac228b..c9a44914b987053 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -507,7 +507,7 @@ func.func @test_call_return_val() -> i32 {
 // result: (x + x') + i(y + y')
 
 func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.addc %a, %b : !fir.complex<16>
+  %c = fir.addc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -518,8 +518,8 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]]  : f128
+// CHECK:         %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[ADD_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[ADD_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -532,7 +532,7 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // result: (x - x') + i(y - y')
 
 func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.subc %a, %b : !fir.complex<16>
+  %c = fir.subc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -543,8 +543,8 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]]  : f128
+// CHECK:         %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>}  : f128
+// CHECK:         %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[SUB_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[SUB_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -557,7 +557,7 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // result: (xx'-yy')+i(xy'+yx')
 
 func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.mulc %a, %b : !fir.complex<16>
+  %c = fir.mulc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -568,12 +568,12 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : f128
-// CHECK:         %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]]  : f128
-// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : f128
-// CHECK:         %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : f128
+// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[SUB]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[ADD]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -586,7 +586,7 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
 
 func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.divc %a, %b : !fir.complex<16>
+  %c = fir.divc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -597,17 +597,17 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]]  : f128
-// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : f128
-// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : f128
-// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]]  : f128
-// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]]  : f128
-// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : f128
-// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]]  : f128
-// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]]  : f128
-// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]]  : f128
+// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>

>From e17b64cfade1351d303511c6651b5ddeef6a889d Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Tue, 31 Oct 2023 11:14:39 +0000
Subject: [PATCH 3/3] Construct LLVM ops directly with the fastmath attribute

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 74 ++++++++++---------------
 1 file changed, 30 insertions(+), 44 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 3f6f2b0474d44b4..9eabacdc818f6f4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -45,6 +45,7 @@
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <mlir/Dialect/LLVMIR/LLVMAttrs.h>
 
 namespace fir {
 #define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -3497,14 +3498,20 @@ struct AbsentOpConversion : public FIROpConversion<fir::AbsentOp> {
 // Primitive operations on Complex types
 //
 
+template <typename OPTY>
+static inline mlir::LLVM::FastmathFlagsAttr getLLVMFMFAttr(OPTY op) {
+  return mlir::LLVM::FastmathFlagsAttr::get(
+      op.getContext(),
+      mlir::arith::convertArithFastMathFlagsToLLVM(op.getFastmath()));
+}
+
 /// Generate inline code for complex addition/subtraction
 template <typename LLVMOP, typename OPTY>
 static mlir::LLVM::InsertValueOp
 complexSum(OPTY sumop, mlir::ValueRange opnds,
            mlir::ConversionPatternRewriter &rewriter,
            const fir::LLVMTypeConverter &lowering) {
-  mlir::LLVM::FastmathFlags fastmathFlags =
-      mlir::arith::convertArithFastMathFlagsToLLVM(sumop.getFastmath());
+  mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(sumop);
   mlir::Value a = opnds[0];
   mlir::Value b = opnds[1];
   auto loc = sumop.getLoc();
@@ -3514,10 +3521,8 @@ complexSum(OPTY sumop, mlir::ValueRange opnds,
   auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
   auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
   auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
-  auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
-  rx.setFastmathFlags(fastmathFlags);
-  auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
-  ry.setFastmathFlags(fastmathFlags);
+  auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1, fmf);
+  auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1, fmf);
   auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
   auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
   return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
@@ -3565,8 +3570,7 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
     // TODO: Can we use a call to __muldc3 ?
     // given: (x + iy) * (x' + iy')
     // result: (xx'-yy')+i(xy'+yx')
-    mlir::LLVM::FastmathFlags fastmathFlags =
-        mlir::arith::convertArithFastMathFlagsToLLVM(mulc.getFastmath());
+    mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(mulc);
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = mulc.getLoc();
@@ -3576,18 +3580,12 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
-    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
-    xx.setFastmathFlags(fastmathFlags);
-    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
-    yx.setFastmathFlags(fastmathFlags);
-    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
-    xy.setFastmathFlags(fastmathFlags);
-    auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
-    ri.setFastmathFlags(fastmathFlags);
-    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
-    yy.setFastmathFlags(fastmathFlags);
-    auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
-    rr.setFastmathFlags(fastmathFlags);
+    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1, fmf);
+    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1, fmf);
+    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1, fmf);
+    auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx, fmf);
+    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1, fmf);
+    auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy, fmf);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
@@ -3607,8 +3605,7 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
     // Just generate inline code for now.
     // given: (x + iy) / (x' + iy')
     // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
-    mlir::LLVM::FastmathFlags fastmathFlags =
-        mlir::arith::convertArithFastMathFlagsToLLVM(divc.getFastmath());
+    mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(divc);
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
@@ -3618,28 +3615,17 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
-    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
-    xx.setFastmathFlags(fastmathFlags);
-    auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
-    x1x1.setFastmathFlags(fastmathFlags);
-    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
-    yx.setFastmathFlags(fastmathFlags);
-    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
-    xy.setFastmathFlags(fastmathFlags);
-    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
-    yy.setFastmathFlags(fastmathFlags);
-    auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
-    y1y1.setFastmathFlags(fastmathFlags);
-    auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
-    d.setFastmathFlags(fastmathFlags);
-    auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
-    rrn.setFastmathFlags(fastmathFlags);
-    auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
-    rin.setFastmathFlags(fastmathFlags);
-    auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
-    rr.setFastmathFlags(fastmathFlags);
-    auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
-    ri.setFastmathFlags(fastmathFlags);
+    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1, fmf);
+    auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1, fmf);
+    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1, fmf);
+    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1, fmf);
+    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1, fmf);
+    auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1, fmf);
+    auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1, fmf);
+    auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy, fmf);
+    auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy, fmf);
+    auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d, fmf);
+    auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d, fmf);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);



More information about the flang-commits mailing list