[Mlir-commits] [mlir] [mlir][math] `powf(a, b)` drop support when a < 0 (PR #126338)
Hyunsung Lee
llvmlistbot at llvm.org
Sun Feb 9 06:55:24 PST 2025
================
@@ -311,40 +316,113 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Convert Powf(float a, float b) for special cases where the exponent b is a
+// compile-time constant (scalar or splat):
+//   b == 1.0  -> a
+//   b == 0.0  -> 1
+//   b == 0.5  -> sqrt(a)
+//   b == -0.5 -> rsqrt(a)
+//   b == -1.0 -> 1 / a
+//   any other integer b -> repeated multiplication (exponentiation by
+//                          squaring), using 1 / a as the base when b < 0.
+// The integer expansions use only multiplications and divisions, so they are
+// also valid when `a` is negative.
+//
+// The exponent is inspected with APFloat queries instead of being converted
+// to a host `float`: APFloat::convertToFloat() asserts for semantics a host
+// float cannot represent (e.g. f64), and rounding through `float` could make
+// an f64 exponent such as 1.0000000001 compare equal to 1.0.
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+                                          PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+
+  auto floatType =
+      dyn_cast<mlir::FloatType>(getElementTypeOrSelf(operandB.getType()));
+  if (!floatType)
+    return failure();
+
+  // Only constant exponents are handled here.
+  APFloat valueB(floatType.getFloatSemantics());
+  if (!matchPattern(operandB, m_ConstantFloat(&valueB)))
+    return failure();
+
+  if (valueB.isExactlyValue(1.0)) {
+    // a^1 -> a
+    rewriter.replaceOp(op, operandA);
+    return success();
+  }
+
+  if (valueB.isZero()) {
+    // a^0 -> 1 (also covers b == -0.0).
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    rewriter.replaceOp(op, one);
+    return success();
+  }
+
+  if (valueB.isExactlyValue(0.5)) {
+    // a^(1/2) -> sqrt(a)
+    // NOTE(review): if math.sqrt follows IEEE-754 sqrt, this differs from C
+    // pow() for a == -0.0 (-0.0 vs +0.0) and a == -inf (NaN vs +inf).
+    Value sqrt = b.create<math::SqrtOp>(operandA);
+    rewriter.replaceOp(op, sqrt);
+    return success();
+  }
+
+  if (valueB.isExactlyValue(-0.5)) {
+    // a^(-1/2) -> 1 / sqrt(a)
+    Value rsqrt = b.create<math::RsqrtOp>(operandA);
+    rewriter.replaceOp(op, rsqrt);
+    return success();
+  }
+
+  if (valueB.isExactlyValue(-1.0)) {
+    // a^(-1) -> 1 / a
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    Value div = b.create<arith::DivFOp>(one, operandA);
+    rewriter.replaceOp(op, div);
+    return success();
+  }
+
+  // Any other exponent must be an exact integer. isInteger() is false for NaN
+  // and infinities, so those are rejected here rather than reaching the
+  // integer conversion below.
+  if (!valueB.isInteger())
+    return failure();
+
+  // Extract |b| as an unsigned 64-bit integer; give up if it does not fit.
+  APFloat absB = valueB;
+  absB.clearSign();
+  APFloat::integerPart exponent = 0;
+  bool isExact = false;
+  if (absB.convertToInteger(MutableArrayRef<APFloat::integerPart>(exponent),
+                            /*Width=*/64, /*IsSigned=*/false,
+                            APFloat::rmTowardZero,
+                            &isExact) != APFloat::opOK ||
+      !isExact)
+    return failure();
+
+  // a^(-n) -> (1/a)^n
+  Value base = operandA;
+  if (valueB.isNegative()) {
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    base = b.create<arith::DivFOp>(one, base);
+  }
+
+  // Exponentiation by squaring. |b| >= 2 here (0, 1 and -1 were handled
+  // above), so `result` is always assigned. Squaring stops as soon as no
+  // exponent bits remain, so no dead multiplication is emitted.
+  Value result;
+  Value current = base;
+  while (true) {
+    if (exponent & 1) {
+      // Fold in the current power of the base for each set bit.
+      if (result)
+        result = b.create<arith::MulFOp>(result, current);
+      else
+        result = current;
+    }
+    exponent >>= 1;
+    if (exponent == 0)
+      break;
+    current = b.create<arith::MulFOp>(current, current);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(a)).
+// The rewrite is restricted to a >= 0, since ln(a) is undefined for a < 0.
----------------
ita9naiwa wrote:
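One problem is that there are legitimate use cases where the variable base `a` is negative (`a < 0`) but the constant exponent `b` is a multiple of 2 (e.g. `powf(a, 2.0)`), so the result is still well-defined. cc @hanhanW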
https://github.com/llvm/llvm-project/pull/126338
More information about the Mlir-commits
mailing list