[Mlir-commits] [mlir] [mlir][SPIR-V] Fix math.powf lowering for non-integer exponents (PR #197727)

Igor Wodiany llvmlistbot at llvm.org
Fri May 15 06:46:21 PDT 2026


================
@@ -360,62 +362,45 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     if (!dstType)
       return failure();
 
-    // Get the scalar float type.
-    FloatType scalarFloatType;
-    if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
-      scalarFloatType = scalarType;
-    } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
-      scalarFloatType = cast<FloatType>(vectorType.getElementType());
-    } else {
-      return failure();
-    }
-
-    // Get int type of the same shape as the float type.
-    Type scalarIntType = rewriter.getIntegerType(32);
-    Type intType = scalarIntType;
+    Location loc = powfOp.getLoc();
     auto operandType = adaptor.getRhs().getType();
-    if (auto vectorType = dyn_cast<VectorType>(operandType)) {
-      auto shape = vectorType.getShape();
-      intType = VectorType::get(shape, scalarIntType);
+
+    // ConvertFToS-based parity needs an integer-valued exponent. Otherwise
+    // fall back to exp(y*log(x)), which yields NaN for x<0 (matches C).
+    auto isIntegerValuedConstant = [](Value v) -> bool {
+      Attribute attr;
+      if (!matchPattern(v, m_Constant(&attr)))
+        return false;
+      return TypeSwitch<Attribute, bool>(attr)
+          .Case([](FloatAttr a) { return a.getValue().isInteger(); })
+          .Case([](SplatElementsAttr a) {
+            return a.getSplatValue<APFloat>().isInteger();
+          })
+          .Case([](DenseElementsAttr a) {
+            return llvm::all_of(a.getValues<APFloat>(),
+                                [](const APFloat &v) { return v.isInteger(); });
+          })
+          .Default(false);
+    };
+
+    if (!isIntegerValuedConstant(adaptor.getRhs())) {
+      Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs());
+      Value mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log);
+      rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp, mul);
+      return success();
     }
 
-    // Per GL Pow extended instruction spec:
-    // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
-    Location loc = powfOp.getLoc();
+    // GL.Pow is undefined for x < 0; take abs and conditionally negate the
+    // result when the exponent is odd.
+    Type intType = rewriter.getIntegerType(32);
+    if (auto vectorType = dyn_cast<VectorType>(operandType))
+      intType = VectorType::get(vectorType.getShape(), intType);
+
     Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
     Value lessThan =
         spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
+    Value abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs());
 
-    // Per C/C++ spec:
-    // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
-    // > finite and negative and exponent is finite and non-integer.
-    // Calculate the reminder from the exponent and check whether it is zero.
-    Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
-    Value expRem =
-        spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
-    Value expRemNonZero =
-        spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
-    Value cmpNegativeWithFractionalExp =
-        spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
-    // Create NaN result and replace base value if conditions are met.
-    const auto &floatSemantics = scalarFloatType.getFloatSemantics();
-    const auto nan = APFloat::getNaN(floatSemantics);
-    Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
-    if (auto vectorType = dyn_cast<VectorType>(operandType))
-      nanAttr = DenseElementsAttr::get(vectorType, nan);
-
-    Value nanValue =
-        spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
-    Value lhs =
-        spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
-                                nanValue, adaptor.getLhs());
-    Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
-
-    // TODO: The following just forcefully casts y into an integer value in
-    // order to properly propagate the sign, assuming integer y cases. It
-    // doesn't cover other cases and should be fixed.
-
-    // Cast exponent to integer and calculate exponent % 2 != 0.
     Value intRhs =
----------------
IgWod wrote:

That makes sense. I missed the fact that the value type is float. In that case, is my suggestion about simplifying the lowering sensible?

https://github.com/llvm/llvm-project/pull/197727


More information about the Mlir-commits mailing list