[Mlir-commits] [mlir] [mlir][SPIR-V] Fix math.powf lowering for non-integer exponents (PR #197727)
Igor Wodiany
llvmlistbot at llvm.org
Fri May 15 06:46:21 PDT 2026
================
@@ -360,62 +362,45 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (!dstType)
return failure();
- // Get the scalar float type.
- FloatType scalarFloatType;
- if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
- scalarFloatType = scalarType;
- } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
- scalarFloatType = cast<FloatType>(vectorType.getElementType());
- } else {
- return failure();
- }
-
- // Get int type of the same shape as the float type.
- Type scalarIntType = rewriter.getIntegerType(32);
- Type intType = scalarIntType;
+ Location loc = powfOp.getLoc();
auto operandType = adaptor.getRhs().getType();
- if (auto vectorType = dyn_cast<VectorType>(operandType)) {
- auto shape = vectorType.getShape();
- intType = VectorType::get(shape, scalarIntType);
+
+ // ConvertFToS-based parity needs an integer-valued exponent. Otherwise
+ // fall back to exp(y*log(x)), which yields NaN for x<0 (matches C).
+ auto isIntegerValuedConstant = [](Value v) -> bool {
+ Attribute attr;
+ if (!matchPattern(v, m_Constant(&attr)))
+ return false;
+ return TypeSwitch<Attribute, bool>(attr)
+ .Case([](FloatAttr a) { return a.getValue().isInteger(); })
+ .Case([](SplatElementsAttr a) {
+ return a.getSplatValue<APFloat>().isInteger();
+ })
+ .Case([](DenseElementsAttr a) {
+ return llvm::all_of(a.getValues<APFloat>(),
+ [](const APFloat &v) { return v.isInteger(); });
+ })
+ .Default(false);
+ };
+
+ if (!isIntegerValuedConstant(adaptor.getRhs())) {
+ Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs());
+ Value mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log);
+ rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp, mul);
+ return success();
}
- // Per GL Pow extended instruction spec:
- // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
- Location loc = powfOp.getLoc();
+ // GL.Pow is undefined for x < 0; take abs and conditionally negate the
+ // result when the exponent is odd.
+ Type intType = rewriter.getIntegerType(32);
+ if (auto vectorType = dyn_cast<VectorType>(operandType))
+ intType = VectorType::get(vectorType.getShape(), intType);
+
Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
Value lessThan =
spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
+ Value abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs());
- // Per C/C++ spec:
- // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
- // > finite and negative and exponent is finite and non-integer.
- // Calculate the reminder from the exponent and check whether it is zero.
- Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
- Value expRem =
- spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
- Value expRemNonZero =
- spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
- Value cmpNegativeWithFractionalExp =
- spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
- // Create NaN result and replace base value if conditions are met.
- const auto &floatSemantics = scalarFloatType.getFloatSemantics();
- const auto nan = APFloat::getNaN(floatSemantics);
- Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
- if (auto vectorType = dyn_cast<VectorType>(operandType))
- nanAttr = DenseElementsAttr::get(vectorType, nan);
-
- Value nanValue =
- spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
- Value lhs =
- spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
- nanValue, adaptor.getLhs());
- Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
-
- // TODO: The following just forcefully casts y into an integer value in
- // order to properly propagate the sign, assuming integer y cases. It
- // doesn't cover other cases and should be fixed.
-
- // Cast exponent to integer and calculate exponent % 2 != 0.
Value intRhs =
----------------
IgWod wrote:
That makes sense. I missed the fact that the value type is float. In that case, is my suggestion about simplifying the lowering sensible?
https://github.com/llvm/llvm-project/pull/197727
More information about the Mlir-commits
mailing list