[Mlir-commits] [mlir] [mlir][complex] Support fastmath in the binary op conversion. (PR #65702)

Kai Sasaki llvmlistbot at llvm.org
Thu Sep 7 17:43:42 PDT 2023


https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/65702:

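Complex dialect arithmetic operations can now carry fastmath flags. This PR makes the conversion from the complex dialect to standard ops propagate those flags to the arith dialect ops it generates. For example, `complex.add` and `complex.sub` now lower to `arith.addf` / `arith.subf` ops on the real and imaginary parts that carry the same fastmath flags as the original complex op.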

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

From c471f8fb90b7493f9da718d235eca8ada8498ccb Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 8 Sep 2023 09:41:01 +0900
Subject: [PATCH] [mlir][complex] Support fastmath in the binary op conversion.

Complex dialect arithmetic operations can now carry fastmath flags.
This patch makes the complex-to-standard conversion propagate those
flags to the arith dialect ops it generates.

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
---
 .../ComplexToStandard/ComplexToStandard.cpp   |  9 ++---
 .../convert-to-standard.mlir                  | 34 +++++++++++++++++++
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 2bcec4ea10f92c5..28e490da330f3c3 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -137,15 +137,16 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
     auto type = cast<ComplexType>(adaptor.getLhs().getType());
     auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
-    Value resultReal =
-        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
+    Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
+                                                  fmf.getValue());
     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
-    Value resultImag =
-        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
+    Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
+                                                  fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index bc2ea0dd7a5847a..9b2eef82541952a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -723,3 +723,37 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
 // CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
+
+// -----
+
+// CHECK-LABEL: func @complex_add_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_add_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %add = complex.add %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %add : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sub_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %sub = complex.sub %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %sub : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>



More information about the Mlir-commits mailing list