[Mlir-commits] [mlir] Fix complex power for large inputs. (PR #88387)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 11 05:59:07 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Johannes Reifferscheid (jreiffers)
<details>
<summary>Changes</summary>
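For example, 1e30^1.2 currently overflows, because the old lowering computes a*a + b*b before taking the power.
This change also forwards fastmath flags to the ops it generates.
It ports XLA's logic and was verified with XLA's test suite. Note that rsqrt and sqrt are still broken.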
---
Full diff: https://github.com/llvm/llvm-project/pull/88387.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+95-54)
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+20-1)
``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 462036e51a1f1c..9c82e8105f06e5 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -989,65 +989,107 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
}
};
-/// Coverts x^y = (a+bi)^(c+di) to
+/// Converts lhs^y = (a+bi)^(c+di) to
/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
- ComplexType type, Value a, Value b, Value c,
- Value d) {
+ ComplexType type, Value lhs, Value c, Value d,
+ arith::FastMathFlags fmf) {
auto elementType = cast<FloatType>(type.getElementType());
- // Compute (a*a+b*b)^(0.5c).
- Value aaPbb = builder.create<arith::AddFOp>(
- builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
- Value half = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 0.5));
- Value halfC = builder.create<arith::MulFOp>(half, c);
- Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
-
- // Compute exp(-d*atan2(b,a)).
- Value negD = builder.create<arith::NegFOp>(d);
- Value argX = builder.create<math::Atan2Op>(b, a);
- Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
- Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
-
- // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
- Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
-
- // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
- Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
- Value halfD = builder.create<arith::MulFOp>(half, d);
- Value q = builder.create<arith::AddFOp>(
- builder.create<arith::MulFOp>(c, argX),
- builder.create<arith::MulFOp>(halfD, lnAaPbb));
-
- Value cosQ = builder.create<math::CosOp>(q);
- Value sinQ = builder.create<math::SinOp>(q);
+ Value a = builder.create<complex::ReOp>(lhs);
+ Value b = builder.create<complex::ImOp>(lhs);
+
+ Value abs = builder.create<complex::AbsOp>(lhs, fmf);
+ Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
+
+ Value negD = builder.create<arith::NegFOp>(d, fmf);
+ Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
+ Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
+ Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
+
+ Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
+ Value lnAbs = builder.create<math::LogOp>(abs, fmf);
+ Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
+ Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
+ Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
+ Value cosQ = builder.create<math::CosOp>(q, fmf);
+ Value sinQ = builder.create<math::SinOp>(q, fmf);
+
+ Value inf = builder.create<arith::ConstantOp>(
+ elementType,
+ builder.getFloatAttr(elementType,
+ APFloat::getInf(elementType.getFloatSemantics())));
Value zero = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 0));
+ elementType, builder.getFloatAttr(elementType, 0.0));
Value one = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 1));
-
- Value xEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
- Value yGeZero = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
- Value cEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
- Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
+ elementType, builder.getFloatAttr(elementType, 1.0));
Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
- Value complexOther = builder.create<complex::CreateOp>(
- type, builder.create<arith::MulFOp>(coeff, cosQ),
- builder.create<arith::MulFOp>(coeff, sinQ));
+ Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
+ Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
- // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
+ // Case 0:
+ // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
+ // Branch Cuts for Complex Elementary Functions or Much Ado About
+ // Nothing's Sign Bit, W. Kahan, Section 10.
+ Value absEqZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
+ Value dEqZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
+ Value cEqZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
+ Value bEqZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
+
+ Value zeroLeC =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
+ Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
+ Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
+ Value complexOneOrZero =
+ builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
+ Value coeffCosSin =
+ builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
+ Value cutoff0 = builder.create<arith::SelectOp>(
+ builder.create<arith::AndIOp>(
+ builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
+ complexOneOrZero, coeffCosSin);
+
+ // Case 1:
+ // x^0 is defined to be 1 for any x, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
- return builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(xEqZero, yGeZero),
- builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
- complexOther);
+ Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
+ Value cutoff1 =
+ builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
+
+ // Case 2:
+ // 1^(c + d*i) = 1 + 0*i
+ Value lhsEqOne = builder.create<arith::AndIOp>(
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
+ bEqZero);
+ Value cutoff2 =
+ builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
+
+ // Case 3:
+ // inf^(c + 0*i) = inf + 0*i, c > 0
+ Value lhsEqInf = builder.create<arith::AndIOp>(
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
+ bEqZero);
+ Value rhsGt0 = builder.create<arith::AndIOp>(
+ dEqZero,
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
+ Value cutoff3 = builder.create<arith::SelectOp>(
+ builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
+
+ // Case 4:
+ // inf^(c + 0*i) = 0 + 0*i, c < 0
+ Value rhsLt0 = builder.create<arith::AndIOp>(
+ dEqZero,
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
+ Value cutoff4 = builder.create<arith::SelectOp>(
+ builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
+
+ return cutoff4;
}
struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
@@ -1060,12 +1102,11 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
auto type = cast<ComplexType>(adaptor.getLhs().getType());
auto elementType = cast<FloatType>(type.getElementType());
- Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
- rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
+ rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
+ c, d, op.getFastmath())});
return success();
}
};
@@ -1080,14 +1121,14 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
- Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
Value c = builder.create<arith::ConstantOp>(
elementType, builder.getFloatAttr(elementType, -0.5));
Value d = builder.create<arith::ConstantOp>(
elementType, builder.getFloatAttr(elementType, 0));
- rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
+ rewriter.replaceOp(op,
+ {powOpConversionImpl(builder, type, adaptor.getComplex(),
+ c, d, op.getFastmath())});
return success();
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index a1de61d10bb226..8d2fb09daa87b6 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -753,13 +753,32 @@ func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
// -----
-// CHECK-LABEL: func.func @complex_pow
+// CHECK-LABEL: func.func @complex_pow
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_pow(%lhs: complex<f32>,
%rhs: complex<f32>) -> complex<f32> {
%pow = complex.pow %lhs, %rhs : complex<f32>
return %pow : complex<f32>
}
+// CHECK: %[[A:.*]] = complex.re %[[LHS]]
+// CHECK: %[[B:.*]] = complex.im %[[LHS]]
+// CHECK: math.atan2 %[[B]], %[[A]] : f32
+
+// -----
+
+// CHECK-LABEL: func.func @complex_pow_with_fmf
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+func.func @complex_pow_with_fmf(%lhs: complex<f32>,
+ %rhs: complex<f32>) -> complex<f32> {
+ %pow = complex.pow %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+ return %pow : complex<f32>
+}
+
+// CHECK: %[[A:.*]] = complex.re %[[LHS]]
+// CHECK: %[[B:.*]] = complex.im %[[LHS]]
+// CHECK: math.atan2 %[[B]], %[[A]] fastmath<nnan,contract> : f32
+
// -----
// CHECK-LABEL: func.func @complex_rsqrt
``````````
</details>
https://github.com/llvm/llvm-project/pull/88387
More information about the Mlir-commits mailing list