[Mlir-commits] [mlir] 8891fd5 - [mlir][complex] Fastmath flag support for complex.tanh (#88571)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 14 03:52:07 PDT 2024


Author: Kai Sasaki
Date: 2024-04-14T19:52:03+09:00
New Revision: 8891fd5acbe441d24a1734aa144f3f3dca075620

URL: https://github.com/llvm/llvm-project/commit/8891fd5acbe441d24a1734aa144f3f3dca075620
DIFF: https://github.com/llvm/llvm-project/commit/8891fd5acbe441d24a1734aa144f3f3dca075620.diff

LOG: [mlir][complex] Fastmath flag support for complex.tanh (#88571)

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0664b053fc9e67..49eb575212ffc1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -973,6 +973,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     // The hyperbolic tangent for complex number can be calculated as follows.
     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
@@ -981,17 +982,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
-    Value cosB = rewriter.create<math::CosOp>(loc, imag);
-    Value sinB = rewriter.create<math::SinOp>(loc, imag);
-    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+    Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
+    Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
+    Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
+    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
     Value numerator =
         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
     Value one = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 1));
-    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
-    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
+    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
+                                                fmf);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index b22c1acacaea18..e0e7cdadd317d2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2085,3 +2085,22 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
+  return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file


        


More information about the Mlir-commits mailing list