[Mlir-commits] [mlir] [mlir][complex] Fastmath flag for the trigonometric ops in complex (PR #85563)

Kai Sasaki llvmlistbot at llvm.org
Sun Mar 17 01:02:43 PDT 2024


https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/85563

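Propagate the fastmath flags of the trigonometric ops in the complex dialect (`complex.sin` and `complex.cos`) to the arith and math ops generated by the ComplexToStandard conversion.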

>From 201439e8a4cb455c12703cab69b6419f6af3c495 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sun, 17 Mar 2024 16:51:47 +0900
Subject: [PATCH] [mlir][complex] Fastmath flag for the trigonometric ops in
 complex dialect

---
 .../ComplexToStandard/ComplexToStandard.cpp   | 50 +++++++++++--------
 .../convert-to-standard.mlir                  | 46 +++++++++++++++++
 2 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 76729278ec1b46..17f64f1b65b7c4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -196,6 +196,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -207,14 +208,14 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
     // implementation in the subclass to combine them.
     Value half = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
-    Value exp = rewriter.create<math::ExpOp>(loc, imag);
-    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
-    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
-    Value sin = rewriter.create<math::SinOp>(loc, real);
-    Value cos = rewriter.create<math::CosOp>(loc, real);
+    Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
+    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
+    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
+    Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
+    Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
 
     auto resultPair =
-        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
+        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
                                                    resultPair.second);
@@ -223,15 +224,17 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
 
   virtual std::pair<Value, Value>
   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
-          Value cos, ConversionPatternRewriter &rewriter) const = 0;
+          Value cos, ConversionPatternRewriter &rewriter,
+          arith::FastMathFlagsAttr fmf) const = 0;
 };
 
 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
 
-  std::pair<Value, Value>
-  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
-          Value cos, ConversionPatternRewriter &rewriter) const override {
+  std::pair<Value, Value> combine(Location loc, Value scaledExp,
+                                  Value reciprocalExp, Value sin, Value cos,
+                                  ConversionPatternRewriter &rewriter,
+                                  arith::FastMathFlagsAttr fmf) const override {
     // Complex cosine is defined as;
     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
     // Plugging in:
@@ -241,10 +244,12 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
     // We get:
     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
-    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
-    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
-    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
-    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+    Value sum =
+        rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+    Value diff =
+        rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
     return {resultReal, resultImag};
   }
 };
@@ -813,9 +818,10 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
 
-  std::pair<Value, Value>
-  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
-          Value cos, ConversionPatternRewriter &rewriter) const override {
+  std::pair<Value, Value> combine(Location loc, Value scaledExp,
+                                  Value reciprocalExp, Value sin, Value cos,
+                                  ConversionPatternRewriter &rewriter,
+                                  arith::FastMathFlagsAttr fmf) const override {
     // Complex sine is defined as;
     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
     // Plugging in:
@@ -825,10 +831,12 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
     // We get:
     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
-    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
-    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
-    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
-    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+    Value sum =
+        rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+    Value diff =
+        rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
     return {resultReal, resultImag};
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5918ff2e0f36c8..bac94aae6b746c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1834,3 +1834,49 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
 // CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
 // CHECK: return %[[VAR41]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_cos_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_cos_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %cos = complex.cos %arg fastmath<nnan,contract> : complex<f32>
+  return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK:     return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @complex_sin_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %sin = complex.sin %arg fastmath<nnan,contract> : complex<f32>
+  return %sin : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK:     return %[[RESULT]]



More information about the Mlir-commits mailing list