[Mlir-commits] [mlir] [mlir][math] `powf(a, b)` drop support when a < 0 (PR #126338)
Benoit Jacob
llvmlistbot at llvm.org
Sun Feb 9 18:39:08 PST 2025
================
@@ -311,40 +316,113 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Convert Powf(float a, float b) for some special cases
+// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operandA = op.getOperand(0);
+ Value operandB = op.getOperand(1);
+ auto baseType = operandB.getType();
+
+ auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+ .getFloatSemantics();
+
+ auto valueB = APFloat(sem);
+ if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
+ // Not a constant, return failure
+ return failure();
+ }
+ float floatValueB = valueB.convertToFloat();
+
+ if (floatValueB == 1.0f) {
+ // a^1 -> a
+ rewriter.replaceOp(op, operandA);
+ return success();
+ }
+
+ if (floatValueB == 0.0) {
+ // a^0 -> 1
+ Value one =
+ createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+ rewriter.replaceOp(op, one);
+ return success();
+ }
+
+ if (floatValueB == 0.5f) {
+ // a^(1/2) -> sqrt(a)
+ Value sqrt = b.create<math::SqrtOp>(operandA);
+ rewriter.replaceOp(op, sqrt);
+ return success();
+ }
+
+ if (floatValueB == -0.5f) {
+ // a^(-1/2) -> 1 / sqrt(a)
+ Value rsqrt = b.create<math::RsqrtOp>(operandA);
+ rewriter.replaceOp(op, rsqrt);
+ return success();
+ }
+
+ if (floatValueB == -1.0f) {
+ // a^(-1) -> 1 / a
+ Value one =
+ createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+ Value div = b.create<arith::DivFOp>(one, operandA);
+ rewriter.replaceOp(op, div);
+ return success();
+ }
+
+ // Check if the power is an integer
+ if (floatValueB != std::floor(floatValueB)) {
----------------
bjacob wrote:
Sorry, this is more complicated and fragile than we need this to be :-)
The first problem is that the expression `x == std::floor(x)` does not actually imply that `x` is an integer. For large values of `x`, for instance `f16` values larger than 2048, most integers are not representable, so the return value of `std::floor(x)` won't be an integer in general. It will be some non-integral value `x` for which the condition `x == std::floor(x)` is true, despite it being non-integral.
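To make the representability point concrete without an `f16` type, here is a small sketch that uses 32-bit `float` as a stand-in. It has a 24-bit significand instead of f16's 11 bits, so the cutoff is 2^24 rather than 2048. The helper names are hypothetical and not part of the patch.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Stand-in illustration: 32-bit float has a 24-bit significand (f16 has 11),
// so every integer up to 2^24 = 16777216 is representable, but above that
// only some are. For f16 the same cutoff is 2^11 = 2048.

// Hypothetical helper: does the integer n survive a round trip through float?
bool isExactlyRepresentableAsFloat(int64_t n) {
  // Keep n well inside int64 range so casting the float back is defined.
  assert(n > -(int64_t{1} << 62) && n < (int64_t{1} << 62));
  float f = static_cast<float>(n);  // rounds to the nearest representable float
  return static_cast<int64_t>(f) == n;
}

// The integrality check from the patch, applied to a float exponent.
bool passesFloorCheck(float x) { return x == std::floor(x); }
```

For example, `isExactlyRepresentableAsFloat(16777217)` is false (it rounds to 16777216), while `passesFloorCheck(1.0e30f)` is true, so the floor check on its own does nothing to keep a huge exponent out of a rewrite meant for small integers.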
Don't try to handle arbitrary integer `x`. Just handle the special value `2.0`, and maybe also `3.0` and `4.0` if you want, but that's it. If someone really needs a larger integral exponent to be matched, we can always expand this pattern later.
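A standalone sketch of the narrower pattern suggested above, modelled in plain C++ rather than with MLIR's `PatternRewriter` (all names here are hypothetical, not from the patch). It matches only a fixed list of exponents and expands them into multiplications. Unlike the generic `exp(b * ln(a))` lowering, multiplication also gives the right answer for a negative base.

```cpp
#include <cassert>
#include <optional>

// Hypothetical matcher: accept only a short, fixed list of exponents rather
// than trying to decide whether an arbitrary float is an integer.
std::optional<int> matchSmallIntegerExponent(double b) {
  if (b == 2.0) return 2;
  if (b == 3.0) return 3;
  if (b == 4.0) return 4;
  return std::nullopt;  // anything else falls through to the generic lowering
}

// Hypothetical model of the rewrite: a^n as a short chain of multiplications.
// In an MLIR pattern each '*' would presumably become an arith.mulf.
double expandPow(double a, int n) {
  assert(n >= 2 && n <= 4);
  double square = a * a;
  switch (n) {
    case 2: return square;            // a^2 = a*a
    case 3: return square * a;        // a^3 = (a*a)*a
    default: return square * square;  // a^4 = (a*a)*(a*a)
  }
}
```

Squaring the square for `4.0` takes two multiplications instead of three; extending the list later only means adding another exact comparison.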
https://github.com/llvm/llvm-project/pull/126338