[Mlir-commits] [mlir] 8ddaf75 - Fix rsqrt inaccuracies. (#88691)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 15 02:24:05 PDT 2024
Author: Johannes Reifferscheid
Date: 2024-04-15T11:24:00+02:00
New Revision: 8ddaf750746d7f9b5f7e878870b086edc0f55326
URL: https://github.com/llvm/llvm-project/commit/8ddaf750746d7f9b5f7e878870b086edc0f55326
DIFF: https://github.com/llvm/llvm-project/commit/8ddaf750746d7f9b5f7e878870b086edc0f55326.diff
LOG: Fix rsqrt inaccuracies. (#88691)
The current lowering of complex.rsqrt is inaccurate for large and subnormal
values. This change ports XLA's lowering; the port was verified using XLA's
test suite and the MLIR-based emitters.
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 49eb575212ffc1..3ebee9baff31bd 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -27,9 +27,11 @@ using namespace mlir;
namespace {
-// Returns the absolute value or its square root.
+enum class AbsFn { abs, sqrt, rsqrt };
+
+// Returns the absolute value, its square root or its reciprocal square root.
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
- ImplicitLocOpBuilder &b, bool returnSqrt = false) {
+ ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
Value one = b.create<arith::ConstantOp>(real.getType(),
b.getFloatAttr(real.getType(), 1.0));
@@ -43,7 +45,13 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
Value result;
- if (returnSqrt) {
+ if (fn == AbsFn::rsqrt) {
+ ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
+ min = b.create<math::RsqrtOp>(min, fmf);
+ max = b.create<math::RsqrtOp>(max, fmf);
+ }
+
+ if (fn == AbsFn::sqrt) {
Value quarter = b.create<arith::ConstantOp>(
real.getType(), b.getFloatAttr(real.getType(), 0.25));
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -863,7 +871,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
- Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
+ Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
Value cos = b.create<math::CosOp>(sqrtArg, fmf);
@@ -1147,18 +1155,74 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
LogicalResult
matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
- Value c = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, -0.5));
- Value d = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 0));
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+
+ auto cst = [&](APFloat v) {
+ return b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, v));
+ };
+ const auto &floatSemantics = elementType.getFloatSemantics();
+ Value zero = cst(APFloat::getZero(floatSemantics));
+ Value inf = cst(APFloat::getInf(floatSemantics));
+ Value negHalf = b.create<arith::ConstantOp>(
+ elementType, b.getFloatAttr(elementType, -0.5));
+ Value nan = cst(APFloat::getNaN(floatSemantics));
+
+ Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
+ Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
+ Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
+ Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
+ Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
+ Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
+
+ Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
+ Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
+
+ if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+ arith::FastMathFlags::ninf)) {
+ Value negOne = b.create<arith::ConstantOp>(
+ elementType, b.getFloatAttr(elementType, -1));
+
+ Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
+ Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
+ Value negImagSignedZero =
+ b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
- rewriter.replaceOp(op,
- {powOpConversionImpl(builder, type, adaptor.getComplex(),
- c, d, op.getFastmath())});
+ Value absReal = b.create<math::AbsFOp>(real, fmf);
+ Value absImag = b.create<math::AbsFOp>(imag, fmf);
+
+ Value absImagIsInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+ Value realIsNan =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
+ Value realIsInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
+ Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
+
+ Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
+
+ resultReal =
+ b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
+ resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
+ resultImag);
+ }
+
+ Value isRealZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
+ Value isImagZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
+
+ resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
+ resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
+
+ rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+ resultImag);
return success();
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index e0e7cdadd317d2..8b4ea9777f7976 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -837,6 +837,21 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
return %rsqrt : complex<f32>
}
+// CHECK-COUNT-5: arith.select
+// CHECK-NOT: arith.select
+
+// -----
+
+// CHECK-LABEL: func @complex_rsqrt_nnan_ninf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_rsqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
+ %sqrt = complex.rsqrt %arg fastmath<nnan,ninf> : complex<f32>
+ return %sqrt : complex<f32>
+}
+
+// CHECK-COUNT-3: arith.select
+// CHECK-NOT: arith.select
+
// -----
// CHECK-LABEL: func.func @complex_angle
@@ -2103,4 +2118,4 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
-// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
More information about the Mlir-commits
mailing list