[Mlir-commits] [mlir] [mlir][complex] Support Fastmath flag for complex log ops (PR #69798)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 16:52:52 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

<details>
<summary>Changes</summary>

Progressively adds fastmath flag support to the conversion of the log-type complex ops (`complex.log` and `complex.log1p`).

For more detail, see https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

---
Full diff: https://github.com/llvm/llvm-project/pull/69798.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+21-11) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+48-1) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 174b7ce9fed2df4..bf753c7062f3664 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -499,13 +499,16 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
-    Value resultReal = b.create<math::LogOp>(elementType, abs);
+    Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
+                                         fmf.getValue());
+    Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
+    Value resultImag =
+        b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
@@ -520,6 +523,7 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
@@ -535,15 +539,21 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
     // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
     // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
     // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
-    Value sumSq = b.create<arith::MulFOp>(real, real);
-    sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
-    sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
-    Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
-    Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
+    Value sumSq = b.create<arith::MulFOp>(real, real, fmf.getValue());
+    sumSq = b.create<arith::AddFOp>(
+        sumSq, b.create<arith::MulFOp>(real, two, fmf.getValue()),
+        fmf.getValue());
+    sumSq = b.create<arith::AddFOp>(
+        sumSq, b.create<arith::MulFOp>(imag, imag, fmf.getValue()),
+        fmf.getValue());
+    Value logSumSq =
+        b.create<math::Log1pOp>(elementType, sumSq, fmf.getValue());
+    Value resultReal = b.create<arith::MulFOp>(logSumSq, half, fmf.getValue());
+
+    Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf.getValue());
 
-    Value realPlusOne = b.create<arith::AddFOp>(real, one);
-
-    Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
+    Value resultImag =
+        b.create<math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8264382a02651c2..3af28150fd5c3f3 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -797,4 +797,51 @@ func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
 // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
\ No newline at end of file
+// CHECK: return %[[RES]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_log_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %log = complex.log %arg fastmath<nnan,contract> : complex<f32>
+  return %log : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_log1p_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %log1p = complex.log1p %arg fastmath<nnan,contract> : complex<f32>
+  return %log1p : complex<f32>
+}
+
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/69798


More information about the Mlir-commits mailing list