[Mlir-commits] [mlir] [mlir][complex] Fastmath flag for the trigonometric ops in complex (PR #85563)
Kai Sasaki
llvmlistbot at llvm.org
Sun Mar 17 01:02:43 PDT 2024
https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/85563
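Support fastmath flags when converting the trigonometric ops of the complex dialect (complex.cos and complex.sin) to standard ops: the flags are propagated to every generated math/arith op.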
>From 201439e8a4cb455c12703cab69b6419f6af3c495 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sun, 17 Mar 2024 16:51:47 +0900
Subject: [PATCH] [mlir][complex] Fastmath flag for the trigonometric ops in
complex dialect
---
.../ComplexToStandard/ComplexToStandard.cpp | 50 +++++++++++--------
.../convert-to-standard.mlir | 46 +++++++++++++++++
2 files changed, 75 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 76729278ec1b46..17f64f1b65b7c4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -196,6 +196,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -207,14 +208,14 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
// implementation in the subclass to combine them.
Value half = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
- Value exp = rewriter.create<math::ExpOp>(loc, imag);
- Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
- Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
- Value sin = rewriter.create<math::SinOp>(loc, real);
- Value cos = rewriter.create<math::CosOp>(loc, real);
+ Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
+ Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
+ Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
+ Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
+ Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
auto resultPair =
- combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
+ combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
resultPair.second);
@@ -223,15 +224,17 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
virtual std::pair<Value, Value>
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
- Value cos, ConversionPatternRewriter &rewriter) const = 0;
+ Value cos, ConversionPatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf) const = 0;
};
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
- std::pair<Value, Value>
- combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
- Value cos, ConversionPatternRewriter &rewriter) const override {
+ std::pair<Value, Value> combine(Location loc, Value scaledExp,
+ Value reciprocalExp, Value sin, Value cos,
+ ConversionPatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf) const override {
// Complex cosine is defined as;
// cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
// Plugging in:
@@ -241,10 +244,12 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
// We get:
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
- Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
- Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+ Value sum =
+ rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
+ Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+ Value diff =
+ rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
+ Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
return {resultReal, resultImag};
}
};
@@ -813,9 +818,10 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
- std::pair<Value, Value>
- combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
- Value cos, ConversionPatternRewriter &rewriter) const override {
+ std::pair<Value, Value> combine(Location loc, Value scaledExp,
+ Value reciprocalExp, Value sin, Value cos,
+ ConversionPatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf) const override {
// Complex sine is defined as;
// sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
// Plugging in:
@@ -825,10 +831,12 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
// We get:
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
// Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
- Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
- Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+ Value sum =
+ rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
+ Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+ Value diff =
+ rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
+ Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
return {resultReal, resultImag};
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5918ff2e0f36c8..bac94aae6b746c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1834,3 +1834,49 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
// CHECK: return %[[VAR41]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_cos_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_cos_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %result = complex.cos %arg fastmath<nnan,contract> : complex<f32> // flags should reach every lowered op
+ return %result : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @complex_sin_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %sin = complex.sin %arg fastmath<nnan,contract> : complex<f32> // flags should reach every lowered op
+ return %sin : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]]
More information about the Mlir-commits
mailing list