[flang-commits] [flang] [flang] Add fastmath attributes to complex arithmetic (PR #70690)
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Tue Oct 31 04:16:55 PDT 2023
https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/70690
>From 08fec5ce5f8697cfdc14b333a137a054802f4c57 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 30 Oct 2023 15:42:17 +0000
Subject: [PATCH 1/3] [flang] add fastmath attributes to FIR complex operations
These attributes (when propagated to LLVM) allow multiple operations to
be merged into one, e.g. fused multiply-add.
I will add support for these attributes in CodeGen in my next patch.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 18 ++++++++++++------
flang/test/Lower/HLFIR/binary-ops.f90 | 6 +++---
flang/test/Lower/OpenACC/acc-reduction.f90 | 4 ++--
flang/test/Lower/array-elemental-calls-2.f90 | 2 +-
flang/test/Lower/assignment.f90 | 6 +++---
5 files changed, 21 insertions(+), 15 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dd2e90c3b1a1fde..6e8064a63b7ae0a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2538,12 +2538,18 @@ def fir_NegcOp : ComplexUnaryArithmeticOp<"negc">;
class ComplexArithmeticOp<string mnemonic, list<Trait> traits = []> :
fir_ArithmeticOp<mnemonic, traits>,
- Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs)>;
-
-def fir_AddcOp : ComplexArithmeticOp<"addc", [Commutative]>;
-def fir_SubcOp : ComplexArithmeticOp<"subc">;
-def fir_MulcOp : ComplexArithmeticOp<"mulc", [Commutative]>;
-def fir_DivcOp : ComplexArithmeticOp<"divc">;
+ Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath)>;
+
+def fir_AddcOp : ComplexArithmeticOp<"addc",
+ [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_SubcOp : ComplexArithmeticOp<"subc",
+ [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_MulcOp : ComplexArithmeticOp<"mulc",
+ [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_DivcOp : ComplexArithmeticOp<"divc",
+ [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
// Pow is a builtin call and not a primitive
def fir_CmpcOp : fir_Op<"cmpc",
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 8db6da3de81b291..6b89577cc54581b 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -32,7 +32,7 @@ subroutine complex_add(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK: %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK: %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
subroutine int_sub(x, y, z)
integer :: x, y, z
@@ -65,7 +65,7 @@ subroutine complex_sub(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK: %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK: %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
subroutine int_mul(x, y, z)
integer :: x, y, z
@@ -98,7 +98,7 @@ subroutine complex_mul(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK: %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK: %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
subroutine int_div(x, y, z)
integer :: x, y, z
diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
index b874d5219625df8..8671c280c2fb314 100644
--- a/flang/test/Lower/OpenACC/acc-reduction.f90
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -163,7 +163,7 @@
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
! CHECK: %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK: %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: }
@@ -183,7 +183,7 @@
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
! CHECK: %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK: %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: }
diff --git a/flang/test/Lower/array-elemental-calls-2.f90 b/flang/test/Lower/array-elemental-calls-2.f90
index 94e24a9910bc267..0d6e34c6391c3df 100644
--- a/flang/test/Lower/array-elemental-calls-2.f90
+++ b/flang/test/Lower/array-elemental-calls-2.f90
@@ -144,7 +144,7 @@ subroutine check_cmplx_part()
! CHECK: %[[VAL_13:.*]] = fir.load %{{.*}} : !fir.ref<!fir.complex<8>>
! CHECK: fir.do_loop
! CHECK: %[[VAL_23:.*]] = fir.array_fetch %{{.*}}, %{{.*}} : (!fir.array<10x!fir.complex<8>>, index) -> !fir.complex<8>
-! CHECK: %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] : !fir.complex<8>
+! CHECK: %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] {fastmath = #arith.fastmath<contract>} : !fir.complex<8>
! CHECK: %[[VAL_25:.*]] = fir.extract_value %[[VAL_24]], [1 : index] : (!fir.complex<8>) -> f64
! CHECK: fir.call @_QPelem_func_real(%[[VAL_25]]) {{.*}}: (f64) -> i32
end subroutine
diff --git a/flang/test/Lower/assignment.f90 b/flang/test/Lower/assignment.f90
index 9b5039e3ea88ebd..058842828d2687a 100644
--- a/flang/test/Lower/assignment.f90
+++ b/flang/test/Lower/assignment.f90
@@ -203,7 +203,7 @@ real function divf(a, b)
! CHECK: %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK: %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[ADD]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: return %[[RET]] : !fir.complex<4>
@@ -219,7 +219,7 @@ real function divf(a, b)
! CHECK: %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[SUB:.*]] = fir.subc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK: %[[SUB:.*]] = fir.subc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[SUB]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: return %[[RET]] : !fir.complex<4>
@@ -235,7 +235,7 @@ real function divf(a, b)
! CHECK: %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[MUL:.*]] = fir.mulc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK: %[[MUL:.*]] = fir.mulc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[MUL]] to %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[RET:.*]] = fir.load %[[FCTRES]] : !fir.ref<!fir.complex<4>>
! CHECK: return %[[RET]] : !fir.complex<4>
>From cd4203a78f1506086dc3b7027dafc15c8b90649c Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 30 Oct 2023 16:05:39 +0000
Subject: [PATCH 2/3] [flang] propagate fir complex fast math flags through
lowering
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 26 +++++++++++++
flang/test/Fir/convert-to-llvm.fir | 50 ++++++++++++-------------
2 files changed, 51 insertions(+), 25 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 0f85f89f1a48138..3f6f2b0474d44b4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3502,6 +3503,8 @@ static mlir::LLVM::InsertValueOp
complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
const fir::LLVMTypeConverter &lowering) {
+ mlir::LLVM::FastmathFlags fastmathFlags =
+ mlir::arith::convertArithFastMathFlagsToLLVM(sumop.getFastmath());
mlir::Value a = opnds[0];
mlir::Value b = opnds[1];
auto loc = sumop.getLoc();
@@ -3512,7 +3515,9 @@ complexSum(OPTY sumop, mlir::ValueRange opnds,
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
+ rx.setFastmathFlags(fastmathFlags);
auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
+ ry.setFastmathFlags(fastmathFlags);
auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
@@ -3560,6 +3565,8 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
// TODO: Can we use a call to __muldc3 ?
// given: (x + iy) * (x' + iy')
// result: (xx'-yy')+i(xy'+yx')
+ mlir::LLVM::FastmathFlags fastmathFlags =
+ mlir::arith::convertArithFastMathFlagsToLLVM(mulc.getFastmath());
mlir::Value a = adaptor.getOperands()[0];
mlir::Value b = adaptor.getOperands()[1];
auto loc = mulc.getLoc();
@@ -3570,11 +3577,17 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+ xx.setFastmathFlags(fastmathFlags);
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+ yx.setFastmathFlags(fastmathFlags);
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+ xy.setFastmathFlags(fastmathFlags);
auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
+ ri.setFastmathFlags(fastmathFlags);
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+ yy.setFastmathFlags(fastmathFlags);
auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
+ rr.setFastmathFlags(fastmathFlags);
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
@@ -3594,6 +3607,8 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
// Just generate inline code for now.
// given: (x + iy) / (x' + iy')
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
+ mlir::LLVM::FastmathFlags fastmathFlags =
+ mlir::arith::convertArithFastMathFlagsToLLVM(divc.getFastmath());
mlir::Value a = adaptor.getOperands()[0];
mlir::Value b = adaptor.getOperands()[1];
auto loc = divc.getLoc();
@@ -3604,16 +3619,27 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+ xx.setFastmathFlags(fastmathFlags);
auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+ x1x1.setFastmathFlags(fastmathFlags);
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+ yx.setFastmathFlags(fastmathFlags);
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+ xy.setFastmathFlags(fastmathFlags);
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+ yy.setFastmathFlags(fastmathFlags);
auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+ y1y1.setFastmathFlags(fastmathFlags);
auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+ d.setFastmathFlags(fastmathFlags);
auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+ rrn.setFastmathFlags(fastmathFlags);
auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+ rin.setFastmathFlags(fastmathFlags);
auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+ rr.setFastmathFlags(fastmathFlags);
auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+ ri.setFastmathFlags(fastmathFlags);
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index cecfbff7eac228b..c9a44914b987053 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -507,7 +507,7 @@ func.func @test_call_return_val() -> i32 {
// result: (x + x') + i(y + y')
func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.addc %a, %b : !fir.complex<16>
+ %c = fir.addc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -518,8 +518,8 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] : f128
-// CHECK: %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] : f128
+// CHECK: %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[ADD_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[ADD_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -532,7 +532,7 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// result: (x - x') + i(y - y')
func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.subc %a, %b : !fir.complex<16>
+ %c = fir.subc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -543,8 +543,8 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] : f128
-// CHECK: %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] : f128
+// CHECK: %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[SUB_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[SUB_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -557,7 +557,7 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// result: (xx'-yy')+i(xy'+yx')
func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.mulc %a, %b : !fir.complex<16>
+ %c = fir.mulc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -568,12 +568,12 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
-// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
-// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
-// CHECK: %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] : f128
-// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
-// CHECK: %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
+// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[SUB]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[ADD]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -586,7 +586,7 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.divc %a, %b : !fir.complex<16>
+ %c = fir.divc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -597,17 +597,17 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
-// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128
-// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
-// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128
-// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128
-// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
-// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128
-// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
-// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
+// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
>From e17b64cfade1351d303511c6651b5ddeef6a889d Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Tue, 31 Oct 2023 11:14:39 +0000
Subject: [PATCH 3/3] Construct directly with fast math attr
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 74 ++++++++++---------------
1 file changed, 30 insertions(+), 44 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 3f6f2b0474d44b4..9eabacdc818f6f4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -45,6 +45,7 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
namespace fir {
#define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -3497,14 +3498,20 @@ struct AbsentOpConversion : public FIROpConversion<fir::AbsentOp> {
// Primitive operations on Complex types
//
+template <typename OPTY>
+static inline mlir::LLVM::FastmathFlagsAttr getLLVMFMFAttr(OPTY op) {
+ return mlir::LLVM::FastmathFlagsAttr::get(
+ op.getContext(),
+ mlir::arith::convertArithFastMathFlagsToLLVM(op.getFastmath()));
+}
+
/// Generate inline code for complex addition/subtraction
template <typename LLVMOP, typename OPTY>
static mlir::LLVM::InsertValueOp
complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
const fir::LLVMTypeConverter &lowering) {
- mlir::LLVM::FastmathFlags fastmathFlags =
- mlir::arith::convertArithFastMathFlagsToLLVM(sumop.getFastmath());
+ mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(sumop);
mlir::Value a = opnds[0];
mlir::Value b = opnds[1];
auto loc = sumop.getLoc();
@@ -3514,10 +3521,8 @@ complexSum(OPTY sumop, mlir::ValueRange opnds,
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
- auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
- rx.setFastmathFlags(fastmathFlags);
- auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
- ry.setFastmathFlags(fastmathFlags);
+ auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1, fmf);
+ auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1, fmf);
auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
@@ -3565,8 +3570,7 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
// TODO: Can we use a call to __muldc3 ?
// given: (x + iy) * (x' + iy')
// result: (xx'-yy')+i(xy'+yx')
- mlir::LLVM::FastmathFlags fastmathFlags =
- mlir::arith::convertArithFastMathFlagsToLLVM(mulc.getFastmath());
+ mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(mulc);
mlir::Value a = adaptor.getOperands()[0];
mlir::Value b = adaptor.getOperands()[1];
auto loc = mulc.getLoc();
@@ -3576,18 +3580,12 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
- auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
- xx.setFastmathFlags(fastmathFlags);
- auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
- yx.setFastmathFlags(fastmathFlags);
- auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
- xy.setFastmathFlags(fastmathFlags);
- auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
- ri.setFastmathFlags(fastmathFlags);
- auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
- yy.setFastmathFlags(fastmathFlags);
- auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
- rr.setFastmathFlags(fastmathFlags);
+ auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1, fmf);
+ auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1, fmf);
+ auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1, fmf);
+ auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx, fmf);
+ auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1, fmf);
+ auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy, fmf);
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
@@ -3607,8 +3605,7 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
// Just generate inline code for now.
// given: (x + iy) / (x' + iy')
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
- mlir::LLVM::FastmathFlags fastmathFlags =
- mlir::arith::convertArithFastMathFlagsToLLVM(divc.getFastmath());
+ mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(divc);
mlir::Value a = adaptor.getOperands()[0];
mlir::Value b = adaptor.getOperands()[1];
auto loc = divc.getLoc();
@@ -3618,28 +3615,17 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
- auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
- xx.setFastmathFlags(fastmathFlags);
- auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
- x1x1.setFastmathFlags(fastmathFlags);
- auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
- yx.setFastmathFlags(fastmathFlags);
- auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
- xy.setFastmathFlags(fastmathFlags);
- auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
- yy.setFastmathFlags(fastmathFlags);
- auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
- y1y1.setFastmathFlags(fastmathFlags);
- auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
- d.setFastmathFlags(fastmathFlags);
- auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
- rrn.setFastmathFlags(fastmathFlags);
- auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
- rin.setFastmathFlags(fastmathFlags);
- auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
- rr.setFastmathFlags(fastmathFlags);
- auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
- ri.setFastmathFlags(fastmathFlags);
+ auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1, fmf);
+ auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1, fmf);
+ auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1, fmf);
+ auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1, fmf);
+ auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1, fmf);
+ auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1, fmf);
+ auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1, fmf);
+ auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy, fmf);
+ auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy, fmf);
+ auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d, fmf);
+ auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d, fmf);
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
More information about the flang-commits
mailing list