[Mlir-commits] [mlir] [mlir][arith] wide integer emulation support for fpto*i ops (PR #132375)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 27 09:02:52 PDT 2025
https://github.com/egebeysel updated https://github.com/llvm/llvm-project/pull/132375
>From aba2153312dc47f1aceb904d5162c3f5152590d7 Mon Sep 17 00:00:00 2001
From: Ege Beysel <beysel at roofline.ai>
Date: Thu, 27 Mar 2025 14:00:10 +0100
Subject: [PATCH 1/2] [mlir][arith] add wide integer emulation support for subi
& update sitofp to use it
Signed-off-by: Ege Beysel <beysel at roofline.ai>
---
.../Arith/Transforms/EmulateWideInt.cpp | 56 ++++++++++---
mlir/test/Dialect/Arith/emulate-wide-int.mlir | 55 +++++++++++--
.../CPU/test-wide-int-emulation-subi-i32.mlir | 81 +++++++++++++++++++
3 files changed, 173 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..3226b5d99114a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -866,6 +866,46 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertSubI
+//===----------------------------------------------------------------------===//
+
+/// Lowers a wide `arith.subi` to half-width arithmetic with explicit borrow
+/// propagation from the low half into the high half:
+///   low  = lhs.low - rhs.low
+///   high = (lhs.high - borrow) - rhs.high, borrow = zext(lhs.low <u rhs.low)
+struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type wideTy = op.getType();
+    VectorType emulatedTy =
+        getTypeConverter()->convertType<VectorType>(wideTy);
+    if (!emulatedTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {}", wideTy));
+
+    Type halfTy = reduceInnermostDim(emulatedTy);
+
+    // Split both operands into their [low, high] halves.
+    auto [lhsLow, lhsHigh] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsLow, rhsHigh] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // Low half: ordinary wrap-around subtraction.
+    Value diffLow = rewriter.create<arith::SubIOp>(loc, lhsLow, rhsLow);
+
+    // The low subtraction wrapped (i.e. borrowed) exactly when
+    // lhsLow <u rhsLow; widen that i1 flag to a 0/1 value of the half type.
+    Value borrowed = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, lhsLow, rhsLow);
+    Value borrow = rewriter.create<arith::ExtUIOp>(loc, halfTy, borrowed);
+
+    // High half: take the borrow away first, then the rhs high half.
+    Value highMinusBorrow =
+        rewriter.create<arith::SubIOp>(loc, lhsHigh, borrow);
+    Value diffHigh =
+        rewriter.create<arith::SubIOp>(loc, highMinusBorrow, rhsHigh);
+
+    Value result = constructResultVector(rewriter, loc, emulatedTy,
+                                         {diffLow, diffHigh});
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertSIToFP
//===----------------------------------------------------------------------===//
@@ -885,22 +925,16 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
return rewriter.notifyMatchFailure(
loc, llvm::formatv("unsupported type: {0}", oldTy));
- unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
- Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
- Value allOnesCst = createScalarOrSplatConstant(
- rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
// To avoid operating on very large unsigned numbers, perform the
// conversion on the absolute value. Then, decide whether to negate the
- // result or not based on that sign bit. We assume two's complement and
- // implement negation by flipping all bits and adding 1.
- // Note that this relies on the the other conversion patterns to legalize
- // created ops and narrow the bit widths.
+ // result or not based on that sign bit. We implement negation by
+ // subtracting from zero. Note that this relies on the the other conversion
+ // patterns to legalize created ops and narrow the bit widths.
Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
in, zeroCst);
- Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
- Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+ Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
@@ -1139,7 +1173,7 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
- ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
+ ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
// Bitwise binary ops.
ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ed08779c10266..5603d8e5064cb 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -130,6 +130,44 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
return %x : vector<4xi64>
}
+// Scalar i64 subtraction is emulated on vector<2xi32> as [low, high] halves:
+// low = a.low - b.low and high = (a.high - borrow) - b.high, where borrow is
+// the zero-extended unsigned comparison a.low < b.low.
+// CHECK-LABEL: func @subi_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : i32 from vector<2xi32>
+// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : i1 to i32
+// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : i32
+// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[SUB_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUB_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @subi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+ %x = arith.subi %a, %b : i64
+ return %x : i64
+}
+
+// Vector variant: each vector<4xi64> operand becomes vector<4x2xi32>, and the
+// low/high halves are handled as vector<4x1xi32> strided slices.
+// CHECK-LABEL: func @subi_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : vector<4x1xi32>
+// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK: [[INS0:%.+]] = vector.insert_strided_slice [[SUB_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[SUB_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<4x2xi32>
+func.func @subi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+ %x = arith.subi %a, %b : vector<4xi64>
+ return %x : vector<4xi64>
+}
+
// CHECK-LABEL: func.func @cmpi_eq_scalar
// CHECK-SAME: ([[LHS:%.+]]: vector<2xi32>, [[RHS:%.+]]: vector<2xi32>)
// CHECK-NEXT: [[LHSLOW:%.+]] = vector.extract [[LHS]][0] : i32 from vector<2xi32>
@@ -967,11 +1005,12 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
// CHECK-LABEL: func @sitofp_i64_f64
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64
-// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<2xi32>
-// CHECK: [[ONES1:%.+]] = vector.extract [[VONES]][0] : i32 from vector<2xi32>
-// CHECK-NEXT: [[ONES2:%.+]] = vector.extract [[VONES]][1] : i32 from vector<2xi32>
-// CHECK: arith.xori {{%.+}}, [[ONES1]] : i32
-// CHECK-NEXT: arith.xori {{%.+}}, [[ONES2]] : i32
+// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK: [[ZERO1:%.+]] = vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[ZERO2:%.+]] = vector.extract [[VZERO]][1] : i32 from vector<2xi32>
+// CHECK: arith.subi [[ZERO1]], {{%.+}} : i32
+// CHECK: arith.subi [[ZERO2]], {{%.+}} : i32
// CHECK: [[CST0:%.+]] = arith.constant 0 : i32
// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : i32 to f64
@@ -990,9 +1029,9 @@ func.func @sitofp_i64_f64(%a : i64) -> f64 {
// CHECK-LABEL: func @sitofp_i64_f64_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
-// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<3x2xi32>
-// CHECK: arith.xori
-// CHECK-NEXT: arith.xori
+// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK: arith.subi
+// CHECK: arith.subi
// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
new file mode 100644
index 0000000000000..7f0e8fd111028
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
@@ -0,0 +1,81 @@
+// Check that the wide integer `arith.subi` emulation produces the same result as wide
+// `arith.subi`. Emulate i32 ops with i16 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i16 types.
+func.func @emulate_subi(%lhs: i32, %rhs: i32) -> i32 {
+  %diff = arith.subi %lhs, %rhs : i32
+  return %diff : i32
+}
+
+// Calls the (possibly emulated) subtraction and prints its result.
+func.func @check_subi(%lhs : i32, %rhs : i32) -> () {
+  %diff = func.call @emulate_subi(%lhs, %rhs) : (i32, i32) -> (i32)
+  vector.print %diff : i32
+  return
+}
+
+func.func @entry() {
+  // Small positive operands in both orders.
+  %lhs1 = arith.constant 1 : i32
+  %rhs1 = arith.constant 2 : i32
+
+  // CHECK: -1
+  func.call @check_subi(%lhs1, %rhs1) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%rhs1, %lhs1) : (i32, i32) -> ()
+
+  // Mixed signs.
+  %lhs2 = arith.constant 1 : i32
+  %rhs2 = arith.constant -2 : i32
+
+  // CHECK-NEXT: 3
+  func.call @check_subi(%lhs2, %rhs2) : (i32, i32) -> ()
+  // CHECK-NEXT: -3
+  func.call @check_subi(%rhs2, %lhs2) : (i32, i32) -> ()
+
+  // Both operands negative.
+  %lhs3 = arith.constant -1 : i32
+  %rhs3 = arith.constant -2 : i32
+
+  // CHECK-NEXT: 1
+  func.call @check_subi(%lhs3, %rhs3) : (i32, i32) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_subi(%rhs3, %lhs3) : (i32, i32) -> ()
+
+  // 131074 = 0x00020002: subtracting 3 makes the low i16 half borrow from
+  // the high half.
+  %lhs4 = arith.constant 131074 : i32
+  %rhs4 = arith.constant 3 : i32
+
+  // CHECK-NEXT: 131071
+  func.call @check_subi(%lhs4, %rhs4) : (i32, i32) -> ()
+  // CHECK-NEXT: -131071
+  func.call @check_subi(%rhs4, %lhs4) : (i32, i32) -> ()
+
+  // 0x00FA0403 - 0x00FB0404: the low half borrows and the high half wraps
+  // around.
+  %lhs5 = arith.constant 16385027 : i32
+  %rhs5 = arith.constant 16450564 : i32
+
+  // CHECK-NEXT: -65537
+  func.call @check_subi(%lhs5, %rhs5) : (i32, i32) -> ()
+  // CHECK-NEXT: 65537
+  func.call @check_subi(%rhs5, %lhs5) : (i32, i32) -> ()
+
+  // Max/min signed i32 values; both differences wrap around modulo 2^32.
+  %i32max = arith.constant 2147483647 : i32
+  %i32min = arith.constant -2147483648 : i32
+
+  // CHECK-NEXT: -1
+  func.call @check_subi(%i32max, %i32min) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%i32min, %i32max) : (i32, i32) -> ()
+
+  return
+}
>From a86cb196b0b32ec22a2709a88e5faa21bf61d8fd Mon Sep 17 00:00:00 2001
From: Ege Beysel <beysel at roofline.ai>
Date: Thu, 20 Mar 2025 18:18:18 +0100
Subject: [PATCH 2/2] [mlir][arith] add wide integer emulation support for
fpto*i ops
Signed-off-by: Ege Beysel <beysel at roofline.ai>
---
.../Arith/Transforms/EmulateWideInt.cpp | 128 +++++++++++++++++-
mlir/test/Dialect/Arith/emulate-wide-int.mlir | 109 +++++++++++++++
.../test-wide-int-emulation-fptosi-i64.mlir | 89 ++++++++++++
.../test-wide-int-emulation-fptoui-i64.mlir | 64 +++++++++
4 files changed, 389 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 3226b5d99114a..bd3f53955f0fa 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
@@ -1008,6 +1009,130 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertFPToSI
+//===----------------------------------------------------------------------===//
+
+/// Lowers a wide `arith.fptosi` (float -> i2N) by converting the magnitude
+/// |x| with a wide `arith.fptoui` and negating the integer result (0 - r)
+/// when x is negative. The wide fptoui/subi/select created here are left for
+/// the other emulation patterns to legalize. The result is only meaningful
+/// when x is representable in the signed i2N type; otherwise it is undefined,
+/// as for `arith.fptosi` itself.
+struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value src = adaptor.getIn();
+    Type srcTy = src.getType();
+    Type dstTy = op.getType();
+
+    // Only handle destinations the type converter knows how to split.
+    if (!getTypeConverter()->convertType<VectorType>(dstTy))
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", dstTy));
+
+    // Zero in the source float type and in the (still wide) result type.
+    Value fpZero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(srcTy));
+    Value intZero = createScalarOrSplatConstant(rewriter, loc, dstTy, 0);
+
+    // |x| = select(x < 0, -x, x). Written with arith ops only so the pattern
+    // does not depend on the math dialect (math.absf).
+    Value isNegative = rewriter.create<arith::CmpFOp>(
+        loc, arith::CmpFPredicate::OLT, src, fpZero);
+    Value negatedSrc = rewriter.create<arith::NegFOp>(loc, src);
+    Value magnitude =
+        rewriter.create<arith::SelectOp>(loc, isNegative, negatedSrc, src);
+
+    // Convert the magnitude as unsigned, then restore the sign.
+    Value unsignedRes =
+        rewriter.create<arith::FPToUIOp>(loc, dstTy, magnitude);
+    Value negatedRes =
+        rewriter.create<arith::SubIOp>(loc, intZero, unsignedRes);
+    Value signedRes = rewriter.create<arith::SelectOp>(loc, isNegative,
+                                                       negatedRes, unsignedRes);
+    rewriter.replaceOp(op, signedRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertFPToUI
+//===----------------------------------------------------------------------===//
+
+/// Emulates a wide `arith.fptoui` (float -> i2N) with two narrow conversions
+/// to iN. The result is interpreted as `2^N * high + low`, where
+///   high = fptoui(fp / 2^N)   and   low = fptoui(remf(fp, 2^N)).
+/// Negative inputs, NaNs, infinities and values that overflow i2N produce
+/// undefined results, just like the original `arith.fptoui`.
+struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Get the input float type.
+    Value inFp = adaptor.getIn();
+    Type fpTy = inFp.getType();
+
+    Type intTy = op.getType();
+    auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", intTy));
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    // Each half is an iN scalar, or an iN vector shaped like the input.
+    Type newHalfType = IntegerType::get(inFp.getContext(), newBitWidth);
+    if (auto vecType = dyn_cast<VectorType>(fpTy))
+      newHalfType = VectorType::get(vecType.getShape(), newHalfType);
+
+    const llvm::fltSemantics &fSemantics =
+        cast<FloatType>(getElementTypeOrSelf(fpTy)).getFloatSemantics();
+
+    // Try to materialize 2^N in the input float format. A power of two is
+    // converted exactly unless it exceeds the format's range, so any status
+    // other than opOK means 2^N is not representable. (With TowardZero an
+    // overflow reports only opInexact, so compare against opOK rather than
+    // against a specific error flag.)
+    APInt powBitwidthInt = APInt::getOneBitSet(newBitWidth * 2, newBitWidth);
+    auto powBitwidth = llvm::APFloat(fSemantics);
+    bool powBitwidthFits =
+        powBitwidth.convertFromAPInt(powBitwidthInt, /*IsSigned=*/false,
+                                     llvm::RoundingMode::TowardZero) ==
+        llvm::APFloat::opOK;
+
+    Value resHigh;
+    Value resLow;
+    if (powBitwidthFits) {
+      TypedAttr powBitwidthAttr =
+          FloatAttr::get(getElementTypeOrSelf(fpTy), powBitwidth);
+      if (auto vecType = dyn_cast<VectorType>(fpTy))
+        powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
+      Value powBitwidthFloatCst =
+          rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);
+
+      // Dividing by a power of two only rescales the exponent and fptoui
+      // truncates, so for in-range non-negative inputs this is
+      // floor(fp / 2^N).
+      Value fpDivPowBitwidth =
+          rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
+      resHigh =
+          rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
+      // Calculate fp - resHigh * 2^N as the remainder of the division.
+      Value remainder =
+          rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
+      resLow = rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);
+    } else {
+      // Every finite value of the input format is below 2^N, so the high half
+      // is always zero and the low half is a direct narrow conversion. This
+      // also avoids materializing an infinity, which formats without Inf
+      // (e.g. f8E4M3FNUZ, f4E2M1FN) cannot represent: APFloat::getInf yields
+      // NaN or is unreachable for them.
+      resHigh = createScalarOrSplatConstant(rewriter, loc, newHalfType, 0);
+      resLow = rewriter.create<arith::FPToUIOp>(loc, newHalfType, inFp);
+    }
+
+    Value high = appendX1Dim(rewriter, loc, resHigh);
+    Value low = appendX1Dim(rewriter, loc, resLow);
+
+    Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
+
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertTruncI
//===----------------------------------------------------------------------===//
@@ -1184,5 +1309,6 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
- ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
+ ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 5603d8e5064cb..6f4e3b1d8f67a 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -1046,3 +1046,112 @@ func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> {
%r = arith.sitofp %a : vector<3xi64> to vector<3xf64>
return %r : vector<3xf64>
}
+
+// 0x41F0000000000000 is 2^32 as an f64. The high i32 half is
+// fptoui(a / 2^32) and the low i32 half is fptoui(a rem 2^32).
+// CHECK-LABEL: func @fptoui_i64_f64
+// CHECK-SAME: ([[ARG:%.+]]: f64) -> vector<2xi32>
+// CHECK-NEXT: [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT: [[DIV:%.+]] = arith.divf [[ARG]], [[POW]] : f64
+// CHECK-NEXT: [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : f64 to i32
+// CHECK-NEXT: [[REM:%.+]] = arith.remf [[ARG]], [[POW]] : f64
+// CHECK-NEXT: [[LOWHALF:%.+]] = arith.fptoui [[REM]] : f64 to i32
+// CHECK: %{{.+}} = vector.insert [[LOWHALF]], %{{.+}} [0]
+// CHECK-NEXT: [[RESVEC:%.+]] = vector.insert [[HIGHHALF]], %{{.+}} [1]
+// CHECK: return [[RESVEC]] : vector<2xi32>
+func.func @fptoui_i64_f64(%a : f64) -> i64 {
+ %r = arith.fptoui %a : f64 to i64
+ return %r : i64
+}
+
+// Vector variant of the fptoui lowering: both i32 halves are computed on
+// vector<3xi32>, reshaped to vector<3x1xi32>, and inserted as the two columns
+// of the vector<3x2xi32> result.
+// CHECK-LABEL: func @fptoui_i64_f64_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xf64>) -> vector<3x2xi32>
+// CHECK-NEXT: [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT: [[DIV:%.+]] = arith.divf [[ARG]], [[POW]] : vector<3xf64>
+// CHECK-NEXT: [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : vector<3xf64> to vector<3xi32>
+// CHECK-NEXT: [[REM:%.+]] = arith.remf [[ARG]], [[POW]] : vector<3xf64>
+// CHECK-NEXT: [[LOWHALF:%.+]] = arith.fptoui [[REM]] : vector<3xf64> to vector<3xi32>
+// CHECK-DAG: [[HIGHHALFX1:%.+]] = vector.shape_cast [[HIGHHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-DAG: [[LOWHALFX1:%.+]] = vector.shape_cast [[LOWHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK: %{{.+}} = vector.insert_strided_slice [[LOWHALFX1]], %{{.+}} {offsets = [0, 0], strides = [1, 1]}
+// CHECK-NEXT: [[RESVEC:%.+]] = vector.insert_strided_slice [[HIGHHALFX1]], %{{.+}} {offsets = [0, 1], strides = [1, 1]}
+// CHECK: return [[RESVEC]] : vector<3x2xi32>
+func.func @fptoui_i64_f64_vector(%a : vector<3xf64>) -> vector<3xi64> {
+ %r = arith.fptoui %a : vector<3xf64> to vector<3xi64>
+ return %r : vector<3xi64>
+}
+
+// fptosi is lowered to fptoui on |a| followed by a select between the result
+// and 0 - result, keyed on a < 0. The ops emitted for the wide fptoui, subi
+// and select are covered by their own tests, so here only the low half is
+// followed through the fptoui wrapper.
+// CHECK-LABEL: func @fptosi_i64_f64
+// CHECK-SAME: ([[ARG:%.+]]: f64) -> vector<2xi32>
+// CHECK: [[ZEROCST:%.+]] = arith.constant 0.000000e+00 : f64
+// CHECK: [[ZEROCSTINT:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[ISNEGATIVE:%.+]] = arith.cmpf olt, [[ARG]], [[ZEROCST]] : f64
+// CHECK-NEXT: [[NEGATED:%.+]] = arith.negf [[ARG]] : f64
+// CHECK-NEXT: [[ABSVALUE:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATED]], [[ARG]] : f64
+// CHECK-NEXT: [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT: [[DIV:%.+]] = arith.divf [[ABSVALUE]], [[POW]] : f64
+// CHECK-NEXT: [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : f64 to i32
+// CHECK-NEXT: [[REM:%.+]] = arith.remf [[ABSVALUE]], [[POW]] : f64
+// CHECK-NEXT: [[LOWHALF:%.+]] = arith.fptoui [[REM]] : f64 to i32
+// CHECK: vector.insert [[LOWHALF]], %{{.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[FPTOUIRESVEC:%.+]] = vector.insert [[HIGHHALF]]
+// CHECK: [[ZEROCSTINTHALF:%.+]] = vector.extract [[ZEROCSTINT]][0] : i32 from vector<2xi32>
+// CHECK: [[SUB:%.+]] = arith.subi [[ZEROCSTINTHALF]], %{{.+}} : i32
+// CHECK-NEXT: arith.cmpi ult, [[ZEROCSTINTHALF]], %{{.+}} : i32
+// CHECK-NEXT: arith.extui
+// CHECK-NEXT: arith.subi
+// CHECK-NEXT: arith.subi
+// CHECK: vector.insert [[SUB]]
+// CHECK: [[SUBVEC:%.+]] = vector.insert
+// CHECK: [[SUB:%.+]] = vector.extract [[SUBVEC]][0] : i32 from vector<2xi32>
+// CHECK: [[LOWRES:%.+]] = vector.extract [[FPTOUIRESVEC]][0] : i32 from vector<2xi32>
+// CHECK: [[ABSRES:%.+]] = arith.select [[ISNEGATIVE]], [[SUB]], [[LOWRES]] : i32
+// CHECK-NEXT: arith.select [[ISNEGATIVE]]
+// CHECK: vector.insert [[ABSRES]]
+// CHECK-NEXT: [[ABSRESVEC:%.+]] = vector.insert
+// CHECK-NEXT: return [[ABSRESVEC]] : vector<2xi32>
+func.func @fptosi_i64_f64(%a : f64) -> i64 {
+ %r = arith.fptosi %a : f64 to i64
+ return %r : i64
+}
+
+// Vector variant of @fptosi_i64_f64. As in the scalar test, ops produced by
+// the other emulation patterns are not re-verified; the sign mask is reshaped
+// to vector<3x1xi1> before selecting between each pair of halves.
+// CHECK-LABEL: func @fptosi_i64_f64_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3xf64>) -> vector<3x2xi32>
+// CHECK-NEXT: [[ZEROCST:%.+]] = arith.constant dense<0.000000e+00> : vector<3xf64>
+// CHECK-NEXT: [[ZEROCSTINT:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT: [[ISNEGATIVE:%.+]] = arith.cmpf olt, [[ARG]], [[ZEROCST]] : vector<3xf64>
+// CHECK-NEXT: [[NEGATED:%.+]] = arith.negf [[ARG]] : vector<3xf64>
+// CHECK-NEXT: [[ABSVALUE:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATED]], [[ARG]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT: [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT: [[DIV:%.+]] = arith.divf [[ABSVALUE]], [[POW]] : vector<3xf64>
+// CHECK-NEXT: [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : vector<3xf64> to vector<3xi32>
+// CHECK-NEXT: [[REM:%.+]] = arith.remf [[ABSVALUE]], [[POW]] : vector<3xf64>
+// CHECK-NEXT: [[LOWHALF:%.+]] = arith.fptoui [[REM]] : vector<3xf64> to vector<3xi32>
+// CHECK-NEXT: [[HIGHHALFX1:%.+]] = vector.shape_cast [[HIGHHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[LOWHALFX1:%.+]] = vector.shape_cast [[LOWHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK: vector.insert_strided_slice [[LOWHALFX1]], %{{.+}} {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT: [[FPTOUIRESVEC:%.+]] = vector.insert_strided_slice [[HIGHHALFX1]]
+// CHECK: [[ZEROCSTINTHALF:%.+]] = vector.extract_strided_slice [[ZEROCSTINT]]
+// CHECK-SAME: {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK: [[SUB:%.+]] = arith.subi [[ZEROCSTINTHALF]], %{{.+}} : vector<3x1xi32>
+// CHECK-NEXT: arith.cmpi ult, [[ZEROCSTINTHALF]], %{{.+}} : vector<3x1xi32>
+// CHECK-NEXT: arith.extui
+// CHECK-NEXT: arith.subi
+// CHECK-NEXT: arith.subi
+// CHECK: vector.insert_strided_slice [[SUB]]
+// CHECK-NEXT: [[SUBVEC:%.+]] = vector.insert_strided_slice
+// CHECK: [[SUB:%.+]] = vector.extract_strided_slice [[SUBVEC]]
+// CHECK-SAME: {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK: [[LOWRES:%.+]] = vector.extract_strided_slice [[FPTOUIRESVEC]]
+// CHECK-SAME: {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK: [[ISNEGATIVEX1:%.+]] = vector.shape_cast [[ISNEGATIVE]] : vector<3xi1> to vector<3x1xi1>
+// CHECK: [[ABSRES:%.+]] = arith.select [[ISNEGATIVEX1]], [[SUB]], [[LOWRES]] : vector<3x1xi1>, vector<3x1xi32>
+// CHECK-NEXT: arith.select [[ISNEGATIVEX1]]
+// CHECK: vector.insert_strided_slice [[ABSRES]]
+// CHECK-NEXT: [[ABSRESVEC:%.+]] = vector.insert_strided_slice
+// CHECK-NEXT: return [[ABSRESVEC]] : vector<3x2xi32>
+func.func @fptosi_i64_f64_vector(%a : vector<3xf64>) -> vector<3xi64> {
+ %r = arith.fptosi %a : vector<3xf64> to vector<3xi64>
+ return %r : vector<3xi64>
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir
new file mode 100644
index 0000000000000..d93b834c8f919
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir
@@ -0,0 +1,89 @@
+// Check that the wide integer `arith.fptosi` emulation produces the same result as wide
+// `arith.fptosi`. Emulate i64 ops with i32 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=32" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i32 types.
+func.func @emulate_fptosi(%x: f64) -> i64 {
+  %converted = arith.fptosi %x : f64 to i64
+  return %converted : i64
+}
+
+// Calls the (possibly emulated) conversion and prints its result.
+func.func @check_fptosi(%x : f64) -> () {
+  %converted = func.call @emulate_fptosi(%x) : (f64) -> (i64)
+  vector.print %converted : i64
+  return
+}
+
+func.func @entry() {
+  // Zeros of both signs and small values, including truncation toward zero.
+  %cst0 = arith.constant 0.0 : f64
+  %cst_nzero = arith.constant 0x8000000000000000 : f64
+  %cst1 = arith.constant 1.0 : f64
+  %cst_n1 = arith.constant -1.0 : f64
+  %cst_n1_5 = arith.constant -1.5 : f64
+
+  // Values that fit entirely in the low i32 half.
+  %cstpow20 = arith.constant 1048576.0 : f64
+  %cstnpow20 = arith.constant -1048576.0 : f64
+
+  // Values around 2^32, where the magnitude starts to use the high i32 half.
+  %cst_i32_max = arith.constant 4294967295.0 : f64
+  %cst_i32_min = arith.constant -4294967296.0 : f64
+  %cst_i32_overflow = arith.constant 4294967296.0 : f64
+  %cst_i32_noverflow = arith.constant -4294967297.0 : f64
+
+  // Values with bits in the high half (2^40) and in both halves (2^40 + 2^20).
+  %cstpow40 = arith.constant 1099511627776.0 : f64
+  %cstnpow40 = arith.constant -1099511627776.0 : f64
+  %cst_pow40ppow20 = arith.constant 1099512676352.0 : f64
+  %cst_npow40ppow20 = arith.constant -1099512676352.0 : f64
+
+  // +-2^53: every integer of magnitude up to this value is exact in f64.
+  %cst_max = arith.constant 9007199254740992.0 : f64
+  %cst_min = arith.constant -9007199254740992.0 : f64
+
+  // CHECK: 0
+  func.call @check_fptosi(%cst0) : (f64) -> ()
+  // CHECK-NEXT: 0
+  func.call @check_fptosi(%cst_nzero) : (f64) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_fptosi(%cst1) : (f64) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_fptosi(%cst_n1) : (f64) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_fptosi(%cst_n1_5) : (f64) -> ()
+  // CHECK-NEXT: 1048576
+  func.call @check_fptosi(%cstpow20) : (f64) -> ()
+  // CHECK-NEXT: -1048576
+  func.call @check_fptosi(%cstnpow20) : (f64) -> ()
+  // CHECK-NEXT: 4294967295
+  func.call @check_fptosi(%cst_i32_max) : (f64) -> ()
+  // CHECK-NEXT: -4294967296
+  func.call @check_fptosi(%cst_i32_min) : (f64) -> ()
+  // CHECK-NEXT: 4294967296
+  func.call @check_fptosi(%cst_i32_overflow) : (f64) -> ()
+  // CHECK-NEXT: -4294967297
+  func.call @check_fptosi(%cst_i32_noverflow) : (f64) -> ()
+  // CHECK-NEXT: 1099511627776
+  func.call @check_fptosi(%cstpow40) : (f64) -> ()
+  // CHECK-NEXT: -1099511627776
+  func.call @check_fptosi(%cstnpow40) : (f64) -> ()
+  // CHECK-NEXT: 1099512676352
+  func.call @check_fptosi(%cst_pow40ppow20) : (f64) -> ()
+  // CHECK-NEXT: -1099512676352
+  func.call @check_fptosi(%cst_npow40ppow20) : (f64) -> ()
+  // CHECK-NEXT: 9007199254740992
+  func.call @check_fptosi(%cst_max) : (f64) -> ()
+  // CHECK-NEXT: -9007199254740992
+  func.call @check_fptosi(%cst_min) : (f64) -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir
new file mode 100644
index 0000000000000..81283ee9fdfd8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir
@@ -0,0 +1,64 @@
+// Check that the wide integer `arith.fptoui` emulation produces the same result as wide
+// `arith.fptoui`. Emulate i64 ops with i32 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=32" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i32 types.
+func.func @emulate_fptoui(%x: f64) -> i64 {
+  %converted = arith.fptoui %x : f64 to i64
+  return %converted : i64
+}
+
+// Calls the (possibly emulated) conversion and prints its result.
+func.func @check_fptoui(%x : f64) -> () {
+  %converted = func.call @emulate_fptoui(%x) : (f64) -> (i64)
+  vector.print %converted : i64
+  return
+}
+
+func.func @entry() {
+  // Zero and small values, including truncation of a fractional part.
+  %cst0 = arith.constant 0.0 : f64
+  %cst1 = arith.constant 1.0 : f64
+  %cst1_5 = arith.constant 1.5 : f64
+
+  // A value in the low i32 half (2^20) and values around 2^32, the boundary
+  // between the low and high halves.
+  %cstpow20 = arith.constant 1048576.0 : f64
+  %cst_i32_max = arith.constant 4294967295.0 : f64
+  %cst_i32_overflow = arith.constant 4294967296.0 : f64
+
+  // Values with bits in the high half (2^40) and in both halves (2^40 + 2^20).
+  %cstpow40 = arith.constant 1099511627776.0 : f64
+  %cst_pow40ppow20 = arith.constant 1099512676352.0 : f64
+
+  // 2^53: every integer of magnitude up to this value is exact in f64.
+  %cst_pow53 = arith.constant 9007199254740992.0 : f64
+
+  // Negative zero converts to 0.
+  %cst_nzero = arith.constant 0x8000000000000000 : f64
+
+  // CHECK: 0
+  func.call @check_fptoui(%cst0) : (f64) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_fptoui(%cst1) : (f64) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_fptoui(%cst1_5) : (f64) -> ()
+  // CHECK-NEXT: 1048576
+  func.call @check_fptoui(%cstpow20) : (f64) -> ()
+  // CHECK-NEXT: 4294967295
+  func.call @check_fptoui(%cst_i32_max) : (f64) -> ()
+  // CHECK-NEXT: 4294967296
+  func.call @check_fptoui(%cst_i32_overflow) : (f64) -> ()
+  // CHECK-NEXT: 1099511627776
+  func.call @check_fptoui(%cstpow40) : (f64) -> ()
+  // CHECK-NEXT: 1099512676352
+  func.call @check_fptoui(%cst_pow40ppow20) : (f64) -> ()
+  // CHECK-NEXT: 0
+  func.call @check_fptoui(%cst_nzero) : (f64) -> ()
+  // CHECK-NEXT: 9007199254740992
+  func.call @check_fptoui(%cst_pow53) : (f64) -> ()
+
+  return
+}
More information about the Mlir-commits
mailing list