[Mlir-commits] [mlir] [mlir][complex] Fastmath flag for complex angle (PR #88658)

Kai Sasaki llvmlistbot at llvm.org
Sun Apr 14 19:15:33 PDT 2024


https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/88658

>From 05053dbdbaf234fa31fa504fc939eeb3be5b6f05 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 12 Apr 2024 13:24:04 +0900
Subject: [PATCH] [mlir][complex] Fastmath flag support for complex.angle

---
 .../ComplexToStandard/ComplexToStandard.cpp   | 17 ++++++----
 .../convert-to-standard.mlir                  | 32 +++++++++++++++++++
 2 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0664b053fc9e67..c92dc35976a546 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -973,6 +973,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     // The hyperbolic tangent for complex number can be calculated as follows.
     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -981,17 +982,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
-    Value cosB = rewriter.create<math::CosOp>(loc, imag);
-    Value sinB = rewriter.create<math::SinOp>(loc, imag);
-    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+    Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
+    Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
+    Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
+    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
     Value numerator =
         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
     Value one = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 1));
-    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
-    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
+    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
+                                                fmf);
     return success();
   }
 };
@@ -1169,13 +1171,14 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto type = op.getType();
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
 
-    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
+    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
 
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index b22c1acacaea18..6d08a1e348895a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2085,3 +2085,35 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
+  return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_angle_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_angle_with_fmf(%arg: complex<f32>) -> f32 {
+  %angle = complex.angle %arg fastmath<nnan,contract> : complex<f32>
+  return %angle : f32
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: return %[[RESULT]] : f32
\ No newline at end of file



More information about the Mlir-commits mailing list