[Mlir-commits] [mlir] [mlir][complex] Support fast math flag for complex.sign op (PR #87148)

Kai Sasaki llvmlistbot at llvm.org
Sat Mar 30 02:21:05 PDT 2024


https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/87148

This patch adds support for the fast math flags given on the `complex.sign` op when converting it to the standard dialect.

See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

From 48bfe4826b91fd25ff4e22891f42dca31375de13 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 29 Mar 2024 16:28:30 +0900
Subject: [PATCH] [mlir][complex] Support fast math flag for complex.sign op

---
 .../ComplexToStandard/ComplexToStandard.cpp   |  7 +--
 .../convert-to-standard.mlir                  | 43 +++++++++++++++++++
 2 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 17f64f1b65b7c4..e18b27702e4ecd 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -918,6 +918,7 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
@@ -928,9 +929,9 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
     Value imagIsZero =
         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
-    auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
-    Value realSign = b.create<arith::DivFOp>(real, abs);
-    Value imagSign = b.create<arith::DivFOp>(imag, abs);
+    auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
+    Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
+    Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
                                                  adaptor.getComplex(), sign);
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index bac94aae6b746c..112918b20d3305 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1880,3 +1880,46 @@ func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @complex_sign_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %sign = complex.sign %arg fastmath<nnan,contract> : complex<f32>
+  return %sign : complex<f32>
+}
+
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>



More information about the Mlir-commits mailing list