[Mlir-commits] [mlir] [mlir][complex] Fastmath flag for the trigonometric ops in complex (PR #85563)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 17 01:03:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

<details>
<summary>Changes</summary>

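Propagate the fastmath flags of the trigonometric ops in the complex dialect (`complex.sin` and `complex.cos`) to the `arith` and `math` ops generated by the ComplexToStandard conversion.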

---
Full diff: https://github.com/llvm/llvm-project/pull/85563.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+29-21) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+46) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 76729278ec1b46..17f64f1b65b7c4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -196,6 +196,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -207,14 +208,14 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
     // implementation in the subclass to combine them.
     Value half = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
-    Value exp = rewriter.create<math::ExpOp>(loc, imag);
-    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
-    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
-    Value sin = rewriter.create<math::SinOp>(loc, real);
-    Value cos = rewriter.create<math::CosOp>(loc, real);
+    Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
+    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
+    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
+    Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
+    Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
 
     auto resultPair =
-        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
+        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
                                                    resultPair.second);
@@ -223,15 +224,17 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
 
   virtual std::pair<Value, Value>
   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
-          Value cos, ConversionPatternRewriter &rewriter) const = 0;
+          Value cos, ConversionPatternRewriter &rewriter,
+          arith::FastMathFlagsAttr fmf) const = 0;
 };
 
 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
 
-  std::pair<Value, Value>
-  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
-          Value cos, ConversionPatternRewriter &rewriter) const override {
+  std::pair<Value, Value> combine(Location loc, Value scaledExp,
+                                  Value reciprocalExp, Value sin, Value cos,
+                                  ConversionPatternRewriter &rewriter,
+                                  arith::FastMathFlagsAttr fmf) const override {
     // Complex cosine is defined as;
     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
     // Plugging in:
@@ -241,10 +244,12 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
     // We get:
     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
-    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
-    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
-    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
-    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+    Value sum =
+        rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+    Value diff =
+        rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
     return {resultReal, resultImag};
   }
 };
@@ -813,9 +818,10 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
 
-  std::pair<Value, Value>
-  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
-          Value cos, ConversionPatternRewriter &rewriter) const override {
+  std::pair<Value, Value> combine(Location loc, Value scaledExp,
+                                  Value reciprocalExp, Value sin, Value cos,
+                                  ConversionPatternRewriter &rewriter,
+                                  arith::FastMathFlagsAttr fmf) const override {
     // Complex sine is defined as;
     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
     // Plugging in:
@@ -825,10 +831,12 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
     // We get:
     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
-    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
-    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
-    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
-    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+    Value sum =
+        rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+    Value diff =
+        rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
     return {resultReal, resultImag};
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5918ff2e0f36c8..bac94aae6b746c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1834,3 +1834,49 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
 // CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
 // CHECK: return %[[VAR41]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_cos_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_cos_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %cos = complex.cos %arg fastmath<nnan,contract> : complex<f32>
+  return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK:     return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @complex_sin_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %cos = complex.sin %arg fastmath<nnan,contract> : complex<f32>
+  return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK:     return %[[RESULT]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/85563


More information about the Mlir-commits mailing list