[Mlir-commits] [mlir] [mlir][math] `powf(a, b)` drop support when a < 0 (PR #126338)

Hyunsung Lee llvmlistbot at llvm.org
Sun Feb 9 06:44:55 PST 2025


================
@@ -311,40 +316,113 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Rewrites Powf(float a, float b) when b is a constant float: b == 1.0, 0.0, 0.5,
+// -0.5 and -1.0 get dedicated forms; any other integral b becomes a multiply chain.
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+                                          PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0); // base `a`
+  Value operandB = op.getOperand(1); // exponent `b`
+  auto baseType = operandB.getType(); // note: this is the exponent operand's type
+
+  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+                  .getFloatSemantics(); // NOTE(review): dyn_cast result is used without a null check
+
+  auto valueB = APFloat(sem);
+  if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
+    // Exponent is not a constant float, nothing to specialize: return failure
+    return failure();
+  }
+  float floatValueB = valueB.convertToFloat(); // NOTE(review): confirm this is valid for non-f32 exponents
+
+  if (floatValueB == 1.0f) {
+    // a^1 -> a (the op is replaced by its base operand, no new ops created)
+    rewriter.replaceOp(op, operandA);
+    return success();
+  }
+
+  if (floatValueB == 0.0) {
+    // a^0 -> 1 (b == -0.0 also takes this branch, since it compares equal to 0.0)
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    rewriter.replaceOp(op, one);
+    return success();
+  }
+
+  if (floatValueB == 0.5f) {
+    // a^(1/2) -> sqrt(a), emitted as math::SqrtOp
+    Value sqrt = b.create<math::SqrtOp>(operandA);
+    rewriter.replaceOp(op, sqrt);
+    return success();
+  }
+
+  if (floatValueB == -0.5f) {
+    // a^(-1/2) -> 1 / sqrt(a), emitted as a single math::RsqrtOp
+    Value rsqrt = b.create<math::RsqrtOp>(operandA);
+    rewriter.replaceOp(op, rsqrt);
+    return success();
+  }
+
+  if (floatValueB == -1.0f) {
+    // a^(-1) -> 1 / a, emitted as arith::DivFOp(1.0, a)
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    Value div = b.create<arith::DivFOp>(one, operandA);
+    rewriter.replaceOp(op, div);
+    return success();
+  }
+
+  // Check if the power is an integer (b must equal its own floor)
+  if (floatValueB != std::floor(floatValueB)) {
+    // We don't handle non-integer powers here, return failure
+    return failure();
+  }
+
+  auto sign = std::signbit(floatValueB) ? -1 : 1; // -1 for a negative exponent, otherwise 1
+  auto absIntValueB = std::abs(static_cast<int>(floatValueB)); // NOTE(review): no bound on |b| is checked before the int conversion
+
+  auto cstOne =
+      createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+  auto base = operandA;
+  if (sign == -1) {
+    base = b.create<arith::DivFOp>(cstOne, base); // negative exponent: raise 1/a to |b| instead
+  }
+  auto current = base; // running power of `base`, squared once per exponent bit
+  auto result = cstOne; // accumulated product, starts at 1.0
+  while (absIntValueB > 0) { // square-and-multiply over the bits of |b|, lowest bit first
+    if (absIntValueB & 1) {
+      result = b.create<arith::MulFOp>(result, current); // this bit is set: fold the current power in
+    }
+    current = b.create<arith::MulFOp>(current, current); // also emitted on the last iteration, where it is unused
+    absIntValueB >>= 1;
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Restricting a >= 0
----------------
ita9naiwa wrote:

Sorry! I forgot to describe the PR in more detail.

This transform should only be applied for small values of 'b' (e.g., when 'abs(b) < 16'), but the heuristic threshold for b has not been determined yet. This last case can be dropped if it's not necessary.

https://github.com/llvm/llvm-project/pull/126338


More information about the Mlir-commits mailing list