[Mlir-commits] [mlir] [mlir][math] `powf(a, b)` drop support when a < 0 (PR #126338)

Benoit Jacob llvmlistbot at llvm.org
Sun Feb 9 18:39:08 PST 2025


================
@@ -311,40 +316,113 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Convert Powf(float a, float b) for some special cases
+// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, or b is an integer
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+                                          PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  auto baseType = operandB.getType();
+
+  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+                  .getFloatSemantics();
+
+  auto valueB = APFloat(sem);
+  if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
+    // Not a constant, return failure
+    return failure();
+  }
+  float floatValueB = valueB.convertToFloat();
+
+  if (floatValueB == 1.0f) {
+    // a^1 -> a
+    rewriter.replaceOp(op, operandA);
+    return success();
+  }
+
+  if (floatValueB == 0.0) {
+    // a^0 -> 1
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    rewriter.replaceOp(op, one);
+    return success();
+  }
+
+  if (floatValueB == 0.5f) {
+    // a^(1/2) -> sqrt(a)
+    Value sqrt = b.create<math::SqrtOp>(operandA);
+    rewriter.replaceOp(op, sqrt);
+    return success();
+  }
+
+  if (floatValueB == -0.5f) {
+    // a^(-1/2) -> 1 / sqrt(a)
+    Value rsqrt = b.create<math::RsqrtOp>(operandA);
+    rewriter.replaceOp(op, rsqrt);
+    return success();
+  }
+
+  if (floatValueB == -1.0f) {
+    // a^(-1) -> 1 / a
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    Value div = b.create<arith::DivFOp>(one, operandA);
+    rewriter.replaceOp(op, div);
+    return success();
+  }
+
+  // Check if the power is an integer
+  if (floatValueB != std::floor(floatValueB)) {
+    // We don't handle non-integer powers here, return failure
+    return failure();
+  }
+
+  auto sign = std::signbit(floatValueB) ? -1 : 1;
+  auto absIntValueB = std::abs(static_cast<int>(floatValueB));
+
+  auto cstOne =
+      createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+  auto base = operandA;
+  if (sign == -1) {
+    base = b.create<arith::DivFOp>(cstOne, base);
+  }
+  auto current = base;
+  auto result = cstOne;
+  while (absIntValueB > 0) {
+    if (absIntValueB & 1) {
+      result = b.create<arith::MulFOp>(result, current);
+    }
+    current = b.create<arith::MulFOp>(current, current);
+    absIntValueB >>= 1;
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Restricting a >= 0
----------------
bjacob wrote:

Just drop the comment `// Restricting a >= 0` here.

Mathematically, the power operation `a^b` is well-defined in two separate (though overlapping) cases:
1. When `a > 0`. In that case, `a^b` is defined as `exp(b * ln(a))`.
2. When `b` is an integer. In that case, `a^b` is defined as the product `a * ... * a` with `|b|` factors (which is 1 when `b == 0`), or the reciprocal of that product if `b` is negative.

These two definitions agree in the intersection of these two cases.
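As a rough scalar illustration (plain C++ `double` arithmetic, not MLIR code), the two modes can be written side by side. The helper names `powPositiveBase`, `powIntegerExponent` and `modesAgree` are hypothetical. The sketch only checks that the modes agree, up to rounding, where they overlap (`a > 0` with integer `b`), while the integer mode also accepts negative bases.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical helper, mode 1: a^b := exp(b * ln(a)), defined only for a > 0.
double powPositiveBase(double a, double b) {
  assert(a > 0.0 && "exp/log definition requires a > 0");
  return std::exp(b * std::log(a));
}

// Hypothetical helper, mode 2: a^n for an integer n, defined for any a
// (a must be nonzero when n < 0). Multiplies |n| copies of a; the empty
// product is 1. A negative exponent takes the reciprocal of the product.
double powIntegerExponent(double a, int n) {
  double product = 1.0;
  for (int i = 0; i < std::abs(n); ++i)
    product *= a;
  return n < 0 ? 1.0 / product : product;
}

// Hypothetical helper: true if both modes agree within a relative tolerance
// on the overlap (a > 0, integer n).
bool modesAgree(double a, int n, double relTol = 1e-12) {
  double viaExpLog = powPositiveBase(a, static_cast<double>(n));
  double viaProduct = powIntegerExponent(a, n);
  return std::fabs(viaExpLog - viaProduct) <= relTol * std::fabs(viaProduct);
}
```

For example, `powIntegerExponent(-2.0, 3)` gives `-8.0`, a value that mode 1 cannot produce at all because `ln(-2)` is undefined.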

Because "power" inherently has this two-mode definition, the MLIR op `powf` should have been specified from the start to implement only one of the two modes. Obviously it should have been the `a > 0` mode: integer exponents are already served by `math.fpowi`.

I believe it is not too late to clarify that. We have recently observed that some rewrite patterns for `powf` ops were broken for `a <= 0`, which suggests that no one was relying on that case.
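To make the failure concrete, here is a minimal sketch, assuming C++ `<cmath>` semantics rather than any actual MLIR lowering. `powViaExpLog` and `expLogMatchesPow` are hypothetical helpers that mirror the `exp(b * ln(a))` formula quoted in the diff. For a non-positive base the formula yields NaN even when `a^b` has a well-defined integer-exponent value.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: the exp/log formula a^b = exp(b * ln(a)), applied
// without any check on the sign of a.
double powViaExpLog(double a, double b) {
  return std::exp(b * std::log(a));
}

// Hypothetical helper: true when the exp/log formula reproduces std::pow(a, b)
// within a relative tolerance. A NaN result never counts as a match, because
// every comparison involving NaN is false.
bool expLogMatchesPow(double a, double b, double relTol = 1e-12) {
  assert(relTol >= 0.0);
  double expected = std::pow(a, b);
  double actual = powViaExpLog(a, b);
  return std::fabs(actual - expected) <= relTol * std::fabs(expected);
}
```

For example, `expLogMatchesPow(-3.0, 2.0)` is false because `std::log(-3.0)` is NaN, and `expLogMatchesPow(0.0, 0.0)` is false because `0 * ln(0)` evaluates `0 * -inf`, which is NaN, whereas `std::pow(0.0, 0.0)` returns 1.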

But that discussion doesn't need to be conflated with this PR, because the rewrites it implements either do not depend on which case we are in (e.g. `pow(a, 2.0)`) or explicitly do not apply to the other case anyway (e.g. `pow(a, 0.5)`).
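For the case-agnostic rewrites, a scalar mirror of the square-and-multiply loop in `convertSpecialPowfOp` shows why they are safe: the loop uses only multiplications and at most one reciprocal, never `ln(a)`, so it is valid for negative bases too. `powBySquaring` is a hypothetical helper written for illustration, not code from the PR.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar mirror of the square-and-multiply loop in
// convertSpecialPowfOp: computes a^n for an integer n with multiplications and
// at most one reciprocal. It never evaluates ln(a), so negative bases are fine.
double powBySquaring(double a, int n) {
  assert((n >= 0 || a != 0.0) && "0^n is undefined for negative n");
  // Negative exponent: invert the base once, then raise it to |n|.
  double current = n < 0 ? 1.0 / a : a;
  long long e = n < 0 ? -static_cast<long long>(n) : n;
  double result = 1.0;
  while (e > 0) {
    if (e & 1)
      result *= current;  // this bit of |n| contributes base^(2^k)
    current *= current;   // advance to base^(2^(k+1))
    e >>= 1;
  }
  return result;
}
```

By contrast, the `pow(a, 0.5) -> sqrt(a)` rewrite only has a real value to reproduce when `a >= 0`; for finite `a < 0`, both `std::sqrt(a)` and `std::pow(a, 0.5)` return NaN.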




https://github.com/llvm/llvm-project/pull/126338

