[Mlir-commits] [mlir] 34ba907 - [mlir][complex] Support Fastmath flag in conversion of complex.sqrt to standard (#85019)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 13 23:53:32 PDT 2024
Author: Kai Sasaki
Date: 2024-03-14T15:53:28+09:00
New Revision: 34ba90745fa55777436a2429a51a3799c83c6d4c
URL: https://github.com/llvm/llvm-project/commit/34ba90745fa55777436a2429a51a3799c83c6d4c
DIFF: https://github.com/llvm/llvm-project/commit/34ba90745fa55777436a2429a51a3799c83c6d4c.diff
LOG: [mlir][complex] Support Fastmath flag in conversion of complex.sqrt to standard (#85019)
When converting the complex.sqrt op to standard, the fast-math flags
attached to the op must be propagated to the ops generated by the lowering.
See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 12a5e202685885..76729278ec1b46 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -845,6 +845,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
auto type = cast<ComplexType>(op.getType());
Type elementType = type.getElementType();
Value arg = adaptor.getComplex();
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
@@ -852,14 +853,14 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
- Value absLhs = b.create<math::AbsFOp>(real);
- Value absArg = b.create<complex::AbsOp>(elementType, arg);
- Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
+ Value absLhs = b.create<math::AbsFOp>(real, fmf);
+ Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
+ Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);
Value half = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 0.5));
- Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
- Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
+ Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
+ Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);
Value realIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
@@ -869,7 +870,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
Value resultReal = sqrtAddAbs;
Value imagDivTwoResultReal = b.create<arith::DivFOp>(
- imag, b.create<arith::AddFOp>(resultReal, resultReal));
+ imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);
Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
@@ -882,7 +883,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
resultReal = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::DivFOp>(
- imag, b.create<arith::AddFOp>(resultImag, resultImag)),
+ imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
resultReal);
Value realIsZero =
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index c80ba48cf8c1a3..5918ff2e0f36c8 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -708,11 +708,60 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
// -----
// CHECK-LABEL: func @complex_sqrt
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
return %sqrt : complex<f32>
}
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
+// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
+// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
+// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] : f32
+// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] : f32
+// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] : f32
+// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] : f32
+// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] : f32
+// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] : f32
+// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] : f32
+// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] : f32
+// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] : f32
+// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] : f32
+// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] : f32
+// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] : f32
+// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
+// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
+// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
+// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
+// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] : f32
+// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
+// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
+// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
+// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
+// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] : f32
+// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] : f32
+// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
+// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
+// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
+// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] : f32
+// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] : f32
+// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
+// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
+// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
+// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
+// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
+// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
+// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
+// CHECK: return %[[VAR41]] : complex<f32>
+
// -----
// CHECK-LABEL: func @complex_conj
@@ -1254,42 +1303,42 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR187:.*]] = complex.re %[[VAR186]] : complex<f32>
// CHECK: %[[VAR188:.*]] = complex.im %[[VAR186]] : complex<f32>
-// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] : f32
+// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] fastmath<nnan,contract> : f32
// CHECK: %[[CST_7:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST_8:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR190:.*]] = complex.re %[[VAR186]] : complex<f32>
// CHECK: %[[VAR191:.*]] = complex.im %[[VAR186]] : complex<f32>
// CHECK: %[[VAR192:.*]] = arith.cmpf oeq, %[[VAR190]], %[[CST_7]] : f32
// CHECK: %[[VAR193:.*]] = arith.cmpf oeq, %[[VAR191]], %[[CST_7]] : f32
-// CHECK: %[[VAR194:.*]] = arith.divf %[[VAR191]], %[[VAR190]] : f32
-// CHECK: %[[VAR195:.*]] = arith.mulf %[[VAR194]], %[[VAR194]] : f32
-// CHECK: %[[VAR196:.*]] = arith.addf %[[VAR195]], %[[CST_8]] : f32
-// CHECK: %[[VAR197:.*]] = math.sqrt %[[VAR196]] : f32
-// CHECK: %[[VAR198:.*]] = math.absf %[[VAR190]] : f32
-// CHECK: %[[VAR199:.*]] = arith.mulf %[[VAR197]], %[[VAR198]] : f32
-// CHECK: %[[VAR200:.*]] = arith.divf %[[VAR190]], %[[VAR191]] : f32
-// CHECK: %[[VAR201:.*]] = arith.mulf %[[VAR200]], %[[VAR200]] : f32
-// CHECK: %[[VAR202:.*]] = arith.addf %[[VAR201]], %[[CST_8]] : f32
-// CHECK: %[[VAR203:.*]] = math.sqrt %[[VAR202]] : f32
-// CHECK: %[[VAR204:.*]] = math.absf %[[VAR191]] : f32
-// CHECK: %[[VAR205:.*]] = arith.mulf %[[VAR203]], %[[VAR204]] : f32
+// CHECK: %[[VAR194:.*]] = arith.divf %[[VAR191]], %[[VAR190]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR195:.*]] = arith.mulf %[[VAR194]], %[[VAR194]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR196:.*]] = arith.addf %[[VAR195]], %[[CST_8]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR197:.*]] = math.sqrt %[[VAR196]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR198:.*]] = math.absf %[[VAR190]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR199:.*]] = arith.mulf %[[VAR197]], %[[VAR198]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR200:.*]] = arith.divf %[[VAR190]], %[[VAR191]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR201:.*]] = arith.mulf %[[VAR200]], %[[VAR200]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR202:.*]] = arith.addf %[[VAR201]], %[[CST_8]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR203:.*]] = math.sqrt %[[VAR202]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR204:.*]] = math.absf %[[VAR191]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR205:.*]] = arith.mulf %[[VAR203]], %[[VAR204]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR206:.*]] = arith.cmpf ogt, %[[VAR190]], %[[VAR191]] : f32
// CHECK: %[[VAR207:.*]] = arith.select %[[VAR206]], %[[VAR199]], %[[VAR205]] : f32
// CHECK: %[[VAR208:.*]] = arith.select %[[VAR193]], %[[VAR198]], %[[VAR207]] : f32
// CHECK: %[[VAR209:.*]] = arith.select %[[VAR192]], %[[VAR204]], %[[VAR208]] : f32
-// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[VAR209]] : f32
+// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[VAR209]] fastmath<nnan,contract> : f32
// CHECK: %[[CST_9:.*]] = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] : f32
-// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] : f32
+// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR213:.*]] = arith.cmpf olt, %[[VAR187]], %[[CST_6]] : f32
// CHECK: %[[VAR214:.*]] = arith.cmpf olt, %[[VAR188]], %[[CST_6]] : f32
-// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] : f32
-// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] : f32
+// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR217:.*]] = arith.negf %[[VAR212]] : f32
// CHECK: %[[VAR218:.*]] = arith.select %[[VAR214]], %[[VAR217]], %[[VAR212]] : f32
// CHECK: %[[VAR219:.*]] = arith.select %[[VAR213]], %[[VAR218]], %[[VAR216]] : f32
-// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] : f32
-// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] : f32
+// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR222:.*]] = arith.select %[[VAR213]], %[[VAR221]], %[[VAR212]] : f32
// CHECK: %[[VAR223:.*]] = arith.cmpf oeq, %[[VAR187]], %[[CST_6]] : f32
// CHECK: %[[VAR224:.*]] = arith.cmpf oeq, %[[VAR188]], %[[CST_6]] : f32
@@ -1728,3 +1777,60 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %sqrt = complex.sqrt %arg fastmath<nnan,contract> : complex<f32>
+ return %sqrt : complex<f32>
+}
+
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] fastmath<nnan,contract> : f32
+// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
+// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
+// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
+// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
+// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
+// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
+// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] fastmath<nnan,contract> : f32
+// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
+// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
+// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
+// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
+// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
+// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
+// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
+// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
+// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
+// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
+// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
+// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
+// CHECK: return %[[VAR41]] : complex<f32>
More information about the Mlir-commits mailing list