[Mlir-commits] [mlir] [mlir][complex] Support Fastmath flag in the conversion of exp, expm1 (PR #67001)

Kai Sasaki llvmlistbot at llvm.org
Thu Sep 21 04:00:17 PDT 2023


https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/67001

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

>From 81cad95937820817f83ba1c356d1cadd49ecbe57 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Tue, 19 Sep 2023 11:28:02 +0900
Subject: [PATCH] [mlir][complex] Support Fastmath flag in the conversion of
 exp,expm1

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
---
 .../ComplexToStandard/ComplexToStandard.cpp   | 18 ++++----
 .../convert-to-standard.mlir                  | 41 +++++++++++++++++++
 2 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 28e490da330f3c3..174b7ce9fed2df4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -446,16 +446,19 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value expReal = rewriter.create<math::ExpOp>(loc, real);
-    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
-    Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
-    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
-    Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
+    Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
+    Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+    Value resultReal =
+        rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
+    Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+    Value resultImag =
+        rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
@@ -471,14 +474,15 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
+    Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
 
     Value real = b.create<complex::ReOp>(elementType, exp);
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1));
-    Value realMinusOne = b.create<arith::SubFOp>(real, one);
+    Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
     Value imag = b.create<complex::ImOp>(elementType, exp);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 9b2eef82541952a..8264382a02651c2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -757,3 +757,44 @@ func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_exp_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %exp = complex.exp %arg fastmath<nnan,contract> : complex<f32>
+  return %exp : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_expm1_with_fmf(
+// CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
+  return %expm1 : complex<f32>
+}
+// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
+// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
+// CHECK: return %[[RES]] : complex<f32>



More information about the Mlir-commits mailing list