[Mlir-commits] [mlir] d230bf3 - [mlir][complex] Support Fastmath flag in the conversion of exp, expm1 (#67001)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 22 18:27:45 PDT 2023
Author: Kai Sasaki
Date: 2023-09-23T10:27:42+09:00
New Revision: d230bf3fce6d7fb1d0a1c5ec10b3d2101adb11d6
URL: https://github.com/llvm/llvm-project/commit/d230bf3fce6d7fb1d0a1c5ec10b3d2101adb11d6
DIFF: https://github.com/llvm/llvm-project/commit/d230bf3fce6d7fb1d0a1c5ec10b3d2101adb11d6.diff
LOG: [mlir][complex] Support Fastmath flag in the conversion of exp, expm1 (#67001)
See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 28e490da330f3c3..174b7ce9fed2df4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -446,16 +446,19 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
Value imag =
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value expReal = rewriter.create<math::ExpOp>(loc, real);
- Value cosImag = rewriter.create<math::CosOp>(loc, imag);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
- Value sinImag = rewriter.create<math::SinOp>(loc, imag);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
+ Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
+ Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+ Value resultReal =
+ rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
+ Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+ Value resultImag =
+ rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -471,14 +474,15 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
ConversionPatternRewriter &rewriter) const override {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
+ Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
Value real = b.create<complex::ReOp>(elementType, exp);
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
- Value realMinusOne = b.create<arith::SubFOp>(real, one);
+ Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
Value imag = b.create<complex::ImOp>(elementType, exp);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 9b2eef82541952a..8264382a02651c2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -757,3 +757,44 @@ func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_exp_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %exp = complex.exp %arg fastmath<nnan,contract> : complex<f32>
+ return %exp : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func.func @complex_expm1_with_fmf(
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
+ return %expm1 : complex<f32>
+}
+// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
+// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
+// CHECK: return %[[RES]] : complex<f32>
\ No newline at end of file
More information about the Mlir-commits
mailing list