[Mlir-commits] [mlir] Fastmathflag complex angle (PR #88658)
Kai Sasaki
llvmlistbot at llvm.org
Sun Apr 14 05:44:01 PDT 2024
https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/88658
>From c9ae5536f5cc9b552b60dac6d2c28fcc36f9ccfa Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 12 Apr 2024 13:24:04 +0900
Subject: [PATCH 1/2] [mlir][complex] Fastmath flag support for complex.tanh
---
.../ComplexToStandard/ComplexToStandard.cpp | 14 ++++++++------
.../convert-to-standard.mlir | 19 +++++++++++++++++++
2 files changed, 27 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..9dc146da7ee142 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -945,6 +945,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
// The hyperbolic tangent for complex number can be calculated as follows.
// tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -953,17 +954,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
Value imag =
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value tanhA = rewriter.create<math::TanhOp>(loc, real);
- Value cosB = rewriter.create<math::CosOp>(loc, imag);
- Value sinB = rewriter.create<math::SinOp>(loc, imag);
- Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+ Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
+ Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
+ Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
+ Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
Value numerator =
rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
Value one = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 1));
- Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+ Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
- rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
+ rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
+ fmf);
return success();
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..5aec9260867f3f 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2017,3 +2017,22 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
+ return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file
>From b5d3e0b6589d3a1020387730ca9ad359540c4c45 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 12 Apr 2024 13:28:03 +0900
Subject: [PATCH 2/2] [mlir][complex] Fastmath flag support for complex.angle
---
.../ComplexToStandard/ComplexToStandard.cpp | 3 ++-
.../ComplexToStandard/convert-to-standard.mlir | 15 ++++++++++++++-
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9dc146da7ee142..ed266a45294410 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -1143,13 +1143,14 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto type = op.getType();
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
Value imag =
rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
- rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
+ rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
return success();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5aec9260867f3f..53b1876f033121 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2035,4 +2035,17 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
-// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func.func @complex_angle_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_angle_with_fmf(%arg: complex<f32>) -> f32 {
+ %angle = complex.angle %arg fastmath<nnan,contract> : complex<f32>
+ return %angle : f32
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: return %[[RESULT]] : f32
\ No newline at end of file
More information about the Mlir-commits
mailing list