[Mlir-commits] [mlir] [mlir][complex] Fastmath flag support for complex.tanh (PR #88571)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 12 12:59:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

<details>
<summary>Changes</summary>

See https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981?u=lewuathe

---
Full diff: https://github.com/llvm/llvm-project/pull/88571.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+8-6) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+19) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..9dc146da7ee142 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -945,6 +945,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     // The hyperbolic tangent of a complex number can be calculated as follows.
     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
@@ -953,17 +954,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
-    Value cosB = rewriter.create<math::CosOp>(loc, imag);
-    Value sinB = rewriter.create<math::SinOp>(loc, imag);
-    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+    Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
+    Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
+    Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
+    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
     Value numerator =
         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
     Value one = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 1));
-    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
-    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
+    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
+                                                fmf);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..5aec9260867f3f 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2017,3 +2017,22 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
+  return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/88571


More information about the Mlir-commits mailing list