[Mlir-commits] [mlir] [mlir][arith] wide integer emulation support for fpto*i ops (PR #132375)

Jakub Kuderski llvmlistbot at llvm.org
Fri Mar 21 07:48:45 PDT 2025


================
@@ -974,6 +974,126 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertFPToSI
+//===----------------------------------------------------------------------===//
+
+/// Emulates `arith.fptosi` with a wide (2N-bit) integer result.
+///
+/// The input magnitude is converted with `arith.fptoui`. The result is then
+/// negated in two's complement when the input was negative. The emitted ops
+/// still use the original wide integer type.
+/// NOTE(review): this relies on the dialect-conversion driver legalizing the
+/// newly created wide ops (e.g. the `arith.fptoui`) with the sibling patterns.
+struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // Floating-point input: a scalar or a vector of floats.
+    Value inFp = adaptor.getIn();
+    Type fpTy = inFp.getType();
+
+    Type intTy = op.getType();
+    unsigned oldBitWidth = getElementTypeOrSelf(intTy).getIntOrFloatBitWidth();
+
+    // Only used as a legality check: bail out if the wide result type cannot
+    // be emulated. The rewrite below is expressed in the original wide type.
+    VectorType newTy = getTypeConverter()->convertType<VectorType>(intTy);
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", intTy));
+
+    // Work on the absolute value and defer the conversion to fptoui, then
+    // negate the result for negative inputs. If minSInt < fp < maxSInt, i.e.
+    // the fp is representable in signed i2N, this emits the correct result.
+    // Otherwise the result is UB.
+
+    // getZeroAttr yields a scalar FloatAttr for scalar inputs and a splat
+    // dense attribute for vector inputs, matching the operand type.
+    TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
+    Value zeroCst = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+
+    Value oneCst = createScalarOrSplatConstant(rewriter, loc, intTy, 1);
+    Value allOnesCst = createScalarOrSplatConstant(
+        rewriter, loc, intTy, APInt::getAllOnes(oldBitWidth));
+
+    // |x| = select(x < 0, -x, x). NaN compares false under OLT and is passed
+    // through unchanged, so its conversion stays UB.
+    Value isNeg = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+                                                 inFp, zeroCst);
+    Value negInFp = rewriter.create<arith::NegFOp>(loc, inFp);
+    Value absVal = rewriter.create<arith::SelectOp>(loc, isNeg, negInFp, inFp);
+
+    // Convert the magnitude as an unsigned wide integer.
+    Value res = rewriter.create<arith::FPToUIOp>(loc, intTy, absVal);
+
+    // Two's-complement negation (~res + 1), selected only for negative inputs.
+    Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, res, allOnesCst);
+    Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertFPToUI
+//===----------------------------------------------------------------------===//
+
+struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    /* Get the input float type */
+    auto inFp = adaptor.getIn();
+    auto fpTy = inFp.getType();
+
+    Type intTy = op.getType();
+    auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", intTy));
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+    Type newHalfType = IntegerType::get(inFp.getContext(), newBitWidth);
+    if (auto vecType = dyn_cast<VectorType>(fpTy))
+      newHalfType = VectorType::get(vecType.getShape(), newHalfType);
+    /*
+    The resulting integer has the upper part and the lower part.
+    This would be interpreted as 2^N * high + low, where N is the bitwidth.
+    Therefore, to calculate the higher part, we emit resHigh = fptoui(fp/2^N).
+    For the lower part, we emit fptoui(fp - resHigh * 2^N).
+    The special cases of overflows including +-inf, NaNs and negative numbers
+    are UB.
+    */
+    double powBitwidth = (uint64_t(1) << newBitWidth);
+    TypedAttr powBitwidthAttr =
+        FloatAttr::get(getElementTypeOrSelf(fpTy), powBitwidth);
+    if (auto vecType = dyn_cast<VectorType>(fpTy))
+      powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
+    Value powBitwidthFloatCst =
+        rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);
+
+    Value fpDivPowBitwidth =
+        rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
+    Value resHigh =
+        rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
+    // Calculate fp - resHigh * 2^N by getting the remainder of the division
+    Value remainder =
+        rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
+    Value resLow =
+        rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);
+
+    auto high = appendX1Dim(rewriter, loc, resHigh);
+    auto low = appendX1Dim(rewriter, loc, resLow);
+
+    auto resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
----------------
kuhar wrote:

Also here: do not use auto

https://github.com/llvm/llvm-project/pull/132375


More information about the Mlir-commits mailing list