[Mlir-commits] [mlir] [mlir][arith] wide integer emulation support for fpto*i ops (PR #132375)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 27 09:30:05 PDT 2025


https://github.com/egebeysel updated https://github.com/llvm/llvm-project/pull/132375

>From 94f474a4bbfc224d13743498977c2c026d4bb9e7 Mon Sep 17 00:00:00 2001
From: Ege Beysel <beysel at roofline.ai>
Date: Thu, 27 Mar 2025 14:00:10 +0100
Subject: [PATCH 1/2] [mlir][arith] add wide integer emulation support for subi
 & update sitofp to use it

Signed-off-by: Ege Beysel <beysel at roofline.ai>
---
 .../Arith/Transforms/EmulateWideInt.cpp       |  56 ++++++++--
 mlir/test/Dialect/Arith/emulate-wide-int.mlir |  55 +++++++--
 .../CPU/test-wide-int-emulation-subi-i32.mlir | 104 ++++++++++++++++++
 3 files changed, 196 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir

diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..3226b5d99114a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -866,6 +866,46 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertSubI
+//===----------------------------------------------------------------------===//
+
+struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    Type newElemTy = reduceInnermostDim(newTy);
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY], where CARRY
+    // is 1 if the low-half subtraction borrows (LHS0 < RHS0) and 0 otherwise.
+    Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
+    // We have a carry if lhsElem0 < rhsElem0.
+    Value carry0 = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
+    Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
+
+    Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
+    Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
+
+    Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
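
For illustration only (not part of the patch), the borrow logic this pattern emits corresponds to the following standalone C++ reference model; the helper name subEmulated is made up:

    #include <cassert>
    #include <cstdint>

    // Split a 64-bit subtraction into 32-bit halves, borrowing from the
    // high half whenever the low-half subtraction wraps (lhs0 < rhs0).
    static uint64_t subEmulated(uint64_t lhs, uint64_t rhs) {
      uint32_t lhs0 = static_cast<uint32_t>(lhs);
      uint32_t lhs1 = static_cast<uint32_t>(lhs >> 32);
      uint32_t rhs0 = static_cast<uint32_t>(rhs);
      uint32_t rhs1 = static_cast<uint32_t>(rhs >> 32);
      uint32_t low = lhs0 - rhs0;              // arith.subi on the low halves
      uint32_t borrow = lhs0 < rhs0 ? 1u : 0u; // arith.cmpi ult + arith.extui
      uint32_t high = lhs1 - borrow - rhs1;    // the two remaining arith.subi ops
      return (static_cast<uint64_t>(high) << 32) | low;
    }

    int main() {
      assert(subEmulated(1, 2) == static_cast<uint64_t>(-1));
      assert(subEmulated(131074, 3) == 131071);
      return 0;
    }

Unsigned wrap-around is well defined in C++, which is what makes this model match the two's-complement semantics of arith.subi.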
+
 //===----------------------------------------------------------------------===//
 // ConvertSIToFP
 //===----------------------------------------------------------------------===//
@@ -885,22 +925,16 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", oldTy));
 
-    unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
     Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
-    Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
-    Value allOnesCst = createScalarOrSplatConstant(
-        rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
 
     // To avoid operating on very large unsigned numbers, perform the
     // conversion on the absolute value. Then, decide whether to negate the
-    // result or not based on that sign bit. We assume two's complement and
-    // implement negation by flipping all bits and adding 1.
-    // Note that this relies on the the other conversion patterns to legalize
-    // created ops and narrow the bit widths.
+    // result or not based on that sign bit. We implement negation by
+    // subtracting from zero. Note that this relies on the other conversion
+    // patterns to legalize created ops and narrow the bit widths.
     Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                  in, zeroCst);
-    Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
-    Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+    Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
     Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
 
     Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
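
The switch above relies on the two's-complement identity 0 - x == ~x + 1: subtracting from zero yields exactly the bits the old flip-and-add-one sequence produced, while letting the new wide subi lowering do the work. A minimal standalone check (not part of the patch):

    #include <cstdint>

    int main() {
      uint32_t x = 0xDEADBEEF;
      // Negation by subtraction from zero equals flipping all bits
      // and adding one (two's complement).
      return (0u - x) == (~x + 1u) ? 0 : 1; // exits 0 on success
    }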
@@ -1139,7 +1173,7 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
       ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
       ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
-      ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
+      ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
       // Bitwise binary ops.
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ed08779c10266..52da80ce26a73 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -130,6 +130,44 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
     return %x : vector<4xi64>
 }
 
+// CHECK-LABEL: func @subi_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[SUB_L:%.+]]  = arith.subi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[ULT:%.+]]    = arith.cmpi ult, [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[CARRY:%.+]]  = arith.extui [[ULT]] : i1 to i32
+// CHECK-NEXT:    [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : i32
+// CHECK-NEXT:    [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[SUB_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[SUB_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @subi_scalar(%a : i64, %b : i64) -> i64 {
+    %x = arith.subi %a, %b : i64
+    return %x : i64
+}
+
+// CHECK-LABEL: func @subi_vector
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[SUB_L:%.+]]  = arith.subi [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT:    [[ULT:%.+]]    = arith.cmpi ult, [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT:    [[CARRY:%.+]]  = arith.extui [[ULT]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT:    [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : vector<4x1xi32>
+// CHECK-NEXT:    [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK:         [[INS0:%.+]]   = vector.insert_strided_slice [[SUB_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert_strided_slice [[SUB_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<4x2xi32>
+func.func @subi_vector(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+    %x = arith.subi %a, %b : vector<4xi64>
+    return %x : vector<4xi64>
+}
+
 // CHECK-LABEL: func.func @cmpi_eq_scalar
 // CHECK-SAME:    ([[LHS:%.+]]: vector<2xi32>, [[RHS:%.+]]: vector<2xi32>)
 // CHECK-NEXT:    [[LHSLOW:%.+]]  = vector.extract [[LHS]][0] : i32 from vector<2xi32>
@@ -967,11 +1005,12 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
 
 // CHECK-LABEL: func @sitofp_i64_f64
 // CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> f64
-// CHECK:         [[VONES:%.+]]  = arith.constant dense<-1> : vector<2xi32>
-// CHECK:         [[ONES1:%.+]]  = vector.extract [[VONES]][0] : i32 from vector<2xi32>
-// CHECK-NEXT:    [[ONES2:%.+]]  = vector.extract [[VONES]][1] : i32 from vector<2xi32>
-// CHECK:                          arith.xori {{%.+}}, [[ONES1]] : i32
-// CHECK-NEXT:                     arith.xori {{%.+}}, [[ONES2]] : i32
+// CHECK:         [[VZERO:%.+]]  = arith.constant dense<0> : vector<2xi32>
+// CHECK:                          vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK:         [[ZERO1:%.+]]  = vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[ZERO2:%.+]]  = vector.extract [[VZERO]][1] : i32 from vector<2xi32>
+// CHECK:                          arith.subi [[ZERO1]], {{%.+}} : i32
+// CHECK:                          arith.subi [[ZERO2]], {{%.+}} : i32
 // CHECK:         [[CST0:%.+]]   = arith.constant 0 : i32
 // CHECK:         [[HIEQ0:%.+]]  = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
 // CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[LOW:%.+]] : i32 to f64
@@ -990,9 +1029,9 @@ func.func @sitofp_i64_f64(%a : i64) -> f64 {
 
 // CHECK-LABEL: func @sitofp_i64_f64_vector
 // CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
-// CHECK:         [[VONES:%.+]]  = arith.constant dense<-1> : vector<3x2xi32>
-// CHECK:                          arith.xori
-// CHECK-NEXT:                     arith.xori
+// CHECK:         [[VZERO:%.+]]  = arith.constant dense<0> : vector<3x2xi32>
+// CHECK:                          arith.subi
+// CHECK:                          arith.subi
 // CHECK:         [[HIEQ0:%.+]]  = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
 // CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
 // CHECK-NEXT:    [[HIFP:%.+]]   = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
new file mode 100644
index 0000000000000..63d2c941c48e7
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
@@ -0,0 +1,104 @@
+// Check that the wide integer `arith.subi` emulation produces the same result as wide
+// `arith.subi`. Emulate i32 ops with i16 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Only the ops in this function will be emulated using i16 types.
+func.func @emulate_subi(%arg: i32, %arg0: i32) -> i32 {
+  %res = arith.subi %arg, %arg0 : i32
+  return %res : i32
+}
+
+func.func @check_subi(%arg : i32, %arg0 : i32) -> () {
+  %res = func.call @emulate_subi(%arg, %arg0) : (i32, i32) -> (i32)
+  vector.print %res : i32
+  return
+}
+
+func.func @entry() {
+  %lhs1 = arith.constant 1 : i32
+  %rhs1 = arith.constant 2 : i32
+
+  // CHECK:       -1
+  func.call @check_subi(%lhs1, %rhs1) : (i32, i32) -> ()
+  // CHECK-NEXT:  1
+  func.call @check_subi(%rhs1, %lhs1) : (i32, i32) -> ()
+  
+  %lhs2 = arith.constant 1 : i32
+  %rhs2 = arith.constant -2 : i32
+
+  // CHECK-NEXT:  3
+  func.call @check_subi(%lhs2, %rhs2) : (i32, i32) -> ()
+  // CHECK-NEXT:  -3
+  func.call @check_subi(%rhs2, %lhs2) : (i32, i32) -> ()
+  
+  %lhs3 = arith.constant -1 : i32
+  %rhs3 = arith.constant -2 : i32
+
+  // CHECK-NEXT:  1
+  func.call @check_subi(%lhs3, %rhs3) : (i32, i32) -> ()
+  // CHECK-NEXT:  -1
+  func.call @check_subi(%rhs3, %lhs3) : (i32, i32) -> ()
+  
+  // Borrow propagates from the lower half into the upper half.
+  %lhs4 = arith.constant 131074 : i32
+  %rhs4 = arith.constant 3 : i32
+
+  // CHECK-NEXT:  131071
+  func.call @check_subi(%lhs4, %rhs4) : (i32, i32) -> ()
+  // CHECK-NEXT:  -131071
+  func.call @check_subi(%rhs4, %lhs4) : (i32, i32) -> ()
+
+  // Wrap-around in both halves.
+  %lhs5 = arith.constant 16385027 : i32 
+  %rhs5 = arith.constant 16450564 : i32
+
+  // CHECK-NEXT:  -65537
+  func.call @check_subi(%lhs5, %rhs5) : (i32, i32) -> ()
+  // CHECK-NEXT:  65537
+  func.call @check_subi(%rhs5, %lhs5) : (i32, i32) -> ()
+
+  %lhs6 = arith.constant 65536 : i32 
+  %rhs6 = arith.constant 1 : i32
+
+  // CHECK-NEXT:  65535
+  func.call @check_subi(%lhs6, %rhs6) : (i32, i32) -> ()
+  // CHECK-NEXT:  -65535
+  func.call @check_subi(%rhs6, %lhs6) : (i32, i32) -> ()
+
+  // Max/Min (un)signed integers.
+  %sintmax = arith.constant 2147483647 : i32 
+  %sintmin = arith.constant -2147483648 : i32
+  %uintmax = arith.constant -1 : i32
+  %uintmin = arith.constant 0 : i32
+  %cst1 = arith.constant 1 : i32
+
+  // CHECK-NEXT:  -1
+  func.call @check_subi(%sintmax, %sintmin) : (i32, i32) -> ()
+  // CHECK-NEXT:  1
+  func.call @check_subi(%sintmin, %sintmax) : (i32, i32) -> ()
+  // CHECK-NEXT:  2147483647
+  func.call @check_subi(%sintmin, %cst1) : (i32, i32) -> ()
+  // CHECK-NEXT:  -2147483648
+  func.call @check_subi(%sintmax, %uintmax) : (i32, i32) -> ()
+  // CHECK-NEXT:  -2
+  func.call @check_subi(%uintmax, %cst1) : (i32, i32) -> ()
+  // CHECK-NEXT:  0
+  func.call @check_subi(%uintmax, %uintmax) : (i32, i32) -> ()
+  // CHECK-NEXT:  -1
+  func.call @check_subi(%uintmin, %cst1) : (i32, i32) -> ()
+  // CHECK-NEXT:  1
+  func.call @check_subi(%uintmin, %uintmax) : (i32, i32) -> ()
+
+  return
+}

>From 3afc66205f1fa9af4dfaa5addd6f4e9d8fc171c2 Mon Sep 17 00:00:00 2001
From: Ege Beysel <beysel at roofline.ai>
Date: Thu, 20 Mar 2025 18:18:18 +0100
Subject: [PATCH 2/2] [mlir][arith] add wide integer emulation support for
 fpto*i ops

Signed-off-by: Ege Beysel <beysel at roofline.ai>
---
 .../Arith/Transforms/EmulateWideInt.cpp       | 128 +++++++++++++++++-
 mlir/test/Dialect/Arith/emulate-wide-int.mlir | 109 +++++++++++++++
 .../test-wide-int-emulation-fptosi-i64.mlir   |  89 ++++++++++++
 .../test-wide-int-emulation-fptoui-i64.mlir   |  64 +++++++++
 4 files changed, 389 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir
 create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir

diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 3226b5d99114a..bd3f53955f0fa 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -1008,6 +1009,130 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertFPToSI
+//===----------------------------------------------------------------------===//
+
+struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Get the input float type.
+    Value inFp = adaptor.getIn();
+    Type fpTy = inFp.getType();
+
+    Type intTy = op.getType();
+
+    auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", intTy));
+
+    // Work on the absolute value and then negate the result if the input is
+    // negative. The conversion of the absolute value is deferred to fptoui.
+    // If minSInt <= fp <= maxSInt, i.e. if the input is representable in a
+    // signed i2N, this emits the correct result; otherwise, the result is UB.
+
+    TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
+    Value zeroCst = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+
+    Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0);
+
+    // Get the absolute value. One could have used math.absf here, but that
+    // introduces an extra dependency.
+    Value isNeg = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+                                                 inFp, zeroCst);
+    Value negInFp = rewriter.create<arith::NegFOp>(loc, inFp);
+
+    Value absVal = rewriter.create<arith::SelectOp>(loc, isNeg, negInFp, inFp);
+
+    // Defer the absolute value to fptoui.
+    Value res = rewriter.create<arith::FPToUIOp>(loc, intTy, absVal);
+
+    // Negate the value if the input was negative.
+    Value neg = rewriter.create<arith::SubIOp>(loc, zeroCstInt, res);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
+    return success();
+  }
+};
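
As a standalone sketch of what this pattern computes (not part of the patch; the helper name fpToSIEmulated is made up):

    #include <cassert>
    #include <cstdint>

    // fptosi via fptoui of the absolute value, followed by conditional
    // negation. Inputs outside the signed 64-bit range are UB, as in the
    // pattern above.
    static int64_t fpToSIEmulated(double fp) {
      bool isNeg = fp < 0.0;                        // arith.cmpf olt
      double absVal = isNeg ? -fp : fp;             // arith.negf + arith.select
      uint64_t res = static_cast<uint64_t>(absVal); // deferred to fptoui
      uint64_t neg = uint64_t{0} - res;             // arith.subi from zero
      return static_cast<int64_t>(isNeg ? neg : res);
    }

    int main() {
      assert(fpToSIEmulated(-1.5) == -1);
      assert(fpToSIEmulated(1099511627776.0) == 1099511627776);
      return 0;
    }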
+
+//===----------------------------------------------------------------------===//
+// ConvertFPToUI
+//===----------------------------------------------------------------------===//
+
+struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Get the input float type.
+    Value inFp = adaptor.getIn();
+    Type fpTy = inFp.getType();
+
+    Type intTy = op.getType();
+    auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", intTy));
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    Type newHalfType = IntegerType::get(inFp.getContext(), newBitWidth);
+    if (auto vecType = dyn_cast<VectorType>(fpTy))
+      newHalfType = VectorType::get(vecType.getShape(), newHalfType);
+
+    // The resulting integer is split into an upper and a lower part,
+    // interpreted as 2^N * high + low, where N is the bit width of each half.
+    // The upper part is computed as resHigh = fptoui(fp / 2^N); the lower
+    // part as fptoui(fp - resHigh * 2^N). The special cases of overflow,
+    // +-inf, NaNs, and negative inputs are UB.
+
+    const llvm::fltSemantics &fSemantics =
+        cast<FloatType>(getElementTypeOrSelf(fpTy)).getFloatSemantics();
+
+    auto powBitwidth = llvm::APFloat(fSemantics);
+    // If 2^N does not fit into the floating-point type, we set powBitwidth
+    // to inf. This ensures that the upper part is correctly set to 0. The
+    // opInexact status can only occur here on overflow, since the converted
+    // number is always a power of two.
+    if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
+                                     false, llvm::RoundingMode::TowardZero) ==
+        llvm::detail::opStatus::opInexact)
+      powBitwidth = llvm::APFloat::getInf(fSemantics);
+
+    TypedAttr powBitwidthAttr =
+        FloatAttr::get(getElementTypeOrSelf(fpTy), powBitwidth);
+    if (auto vecType = dyn_cast<VectorType>(fpTy))
+      powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
+    Value powBitwidthFloatCst =
+        rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);
+
+    Value fpDivPowBitwidth =
+        rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
+    Value resHigh =
+        rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
+    // Calculate fp - resHigh * 2^N by taking the remainder of the division.
+    Value remainder =
+        rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
+    Value resLow =
+        rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);
+
+    Value high = appendX1Dim(rewriter, loc, resHigh);
+    Value low = appendX1Dim(rewriter, loc, resLow);
+
+    Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
+
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
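
The high/low decomposition can be cross-checked with a plain C++ model (not part of the patch; fpToUIEmulated is a made-up name):

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // high = fptoui(fp / 2^32) and low = fptoui(fmod(fp, 2^32)). Dividing
    // by a power of two and taking fmod are both exact in f64, so no
    // precision is lost for in-range inputs.
    static uint64_t fpToUIEmulated(double fp) {
      const double pow32 = 4294967296.0; // 2^32, exactly representable
      uint32_t high = static_cast<uint32_t>(fp / pow32);
      uint32_t low = static_cast<uint32_t>(std::fmod(fp, pow32));
      return (static_cast<uint64_t>(high) << 32) | low;
    }

    int main() {
      assert(fpToUIEmulated(4294967295.0) == 4294967295u);
      assert(fpToUIEmulated(1099512676352.0) == 1099512676352u);
      return 0;
    }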
+
 //===----------------------------------------------------------------------===//
 // ConvertTruncI
 //===----------------------------------------------------------------------===//
@@ -1184,5 +1309,6 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
       ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
       ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
-      ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
+      ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 52da80ce26a73..936050cddb676 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -1046,3 +1046,112 @@ func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> {
     %r = arith.sitofp %a : vector<3xi64> to vector<3xf64>
     return %r : vector<3xf64>
 }
+
+// CHECK-LABEL:   func @fptoui_i64_f64
+// CHECK-SAME:      ([[ARG:%.+]]: f64) -> vector<2xi32>
+// CHECK-NEXT:      [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT:      [[DIV:%.+]] = arith.divf [[ARG]], [[POW]] : f64
+// CHECK-NEXT:      [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : f64 to i32
+// CHECK-NEXT:      [[REM:%.+]] = arith.remf [[ARG]], [[POW]] : f64
+// CHECK-NEXT:      [[LOWHALF:%.+]] = arith.fptoui [[REM]] : f64 to i32
+// CHECK:           %{{.+}} = vector.insert [[LOWHALF]], %{{.+}} [0]
+// CHECK-NEXT:      [[RESVEC:%.+]] = vector.insert [[HIGHHALF]], %{{.+}} [1]
+// CHECK:           return [[RESVEC]] : vector<2xi32>
+func.func @fptoui_i64_f64(%a : f64) -> i64 {
+    %r = arith.fptoui %a : f64 to i64
+    return %r : i64
+}
+
+// CHECK-LABEL:   func @fptoui_i64_f64_vector
+// CHECK-SAME:      ([[ARG:%.+]]: vector<3xf64>) -> vector<3x2xi32>
+// CHECK-NEXT:      [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT:      [[DIV:%.+]] = arith.divf [[ARG]], [[POW]] : vector<3xf64>
+// CHECK-NEXT:      [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : vector<3xf64> to vector<3xi32>
+// CHECK-NEXT:      [[REM:%.+]] = arith.remf [[ARG]], [[POW]] : vector<3xf64>
+// CHECK-NEXT:      [[LOWHALF:%.+]] = arith.fptoui [[REM]] : vector<3xf64> to vector<3xi32>
+// CHECK-DAG:       [[HIGHHALFX1:%.+]] = vector.shape_cast [[HIGHHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-DAG:       [[LOWHALFX1:%.+]] = vector.shape_cast [[LOWHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK:           %{{.+}} = vector.insert_strided_slice [[LOWHALFX1]], %{{.+}} {offsets = [0, 0], strides = [1, 1]}
+// CHECK-NEXT:      [[RESVEC:%.+]] = vector.insert_strided_slice [[HIGHHALFX1]], %{{.+}} {offsets = [0, 1], strides = [1, 1]}
+// CHECK:           return [[RESVEC]] : vector<3x2xi32>
+func.func @fptoui_i64_f64_vector(%a : vector<3xf64>) -> vector<3xi64> {
+    %r = arith.fptoui %a : vector<3xf64> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+
+// This lowering emits ops that are already verified by other patterns.
+// We do not re-verify them here and only check the wrapper around fptoui by following its low part.
+// CHECK-LABEL:   func @fptosi_i64_f64
+// CHECK-SAME:      ([[ARG:%.+]]: f64) -> vector<2xi32>
+// CHECK:           [[ZEROCST:%.+]] = arith.constant 0.000000e+00 : f64
+// CHECK:           [[ZEROCSTINT:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:      [[ISNEGATIVE:%.+]] = arith.cmpf olt, [[ARG]], [[ZEROCST]] : f64
+// CHECK-NEXT:      [[NEGATED:%.+]] = arith.negf [[ARG]] : f64
+// CHECK-NEXT:      [[ABSVALUE:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATED]], [[ARG]] : f64
+// CHECK-NEXT:      [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT:      [[DIV:%.+]] = arith.divf [[ABSVALUE]], [[POW]] : f64
+// CHECK-NEXT:      [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : f64 to i32
+// CHECK-NEXT:      [[REM:%.+]] = arith.remf [[ABSVALUE]], [[POW]] : f64
+// CHECK-NEXT:      [[LOWHALF:%.+]] = arith.fptoui [[REM]] : f64 to i32
+// CHECK:           vector.insert [[LOWHALF]], %{{.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:      [[FPTOUIRESVEC:%.+]] = vector.insert [[HIGHHALF]]
+// CHECK:           [[ZEROCSTINTHALF:%.+]] = vector.extract [[ZEROCSTINT]][0] : i32 from vector<2xi32>
+// CHECK:           [[SUB:%.+]] = arith.subi [[ZEROCSTINTHALF]], %{{.+}} : i32
+// CHECK-NEXT:      arith.cmpi ult, [[ZEROCSTINTHALF]], %{{.+}} : i32
+// CHECK-NEXT:      arith.extui
+// CHECK-NEXT:      arith.subi
+// CHECK-NEXT:      arith.subi
+// CHECK:           vector.insert [[SUB]]
+// CHECK:           [[SUBVEC:%.+]] = vector.insert
+// CHECK:           [[SUB:%.+]] = vector.extract [[SUBVEC]][0] : i32 from vector<2xi32>
+// CHECK:           [[LOWRES:%.+]] = vector.extract [[FPTOUIRESVEC]][0] : i32 from vector<2xi32>
+// CHECK:           [[ABSRES:%.+]] = arith.select [[ISNEGATIVE]], [[SUB]], [[LOWRES]] : i32
+// CHECK-NEXT:      arith.select [[ISNEGATIVE]]
+// CHECK:           vector.insert [[ABSRES]]
+// CHECK-NEXT:      [[ABSRESVEC:%.+]] = vector.insert
+// CHECK-NEXT:      return [[ABSRESVEC]] : vector<2xi32>
+func.func @fptosi_i64_f64(%a : f64) -> i64 {
+    %r = arith.fptosi %a : f64 to i64
+    return %r : i64
+}
+
+// Same as the scalar case above; we don't re-verify the shared ops.
+// CHECK-LABEL:   func @fptosi_i64_f64_vector
+// CHECK-SAME:      ([[ARG:%.+]]: vector<3xf64>) -> vector<3x2xi32>
+// CHECK-NEXT:      [[ZEROCST:%.+]] = arith.constant dense<0.000000e+00> : vector<3xf64>
+// CHECK-NEXT:      [[ZEROCSTINT:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:      [[ISNEGATIVE:%.+]] = arith.cmpf olt, [[ARG]], [[ZEROCST]] : vector<3xf64>
+// CHECK-NEXT:      [[NEGATED:%.+]] = arith.negf [[ARG]] : vector<3xf64>
+// CHECK-NEXT:      [[ABSVALUE:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATED]], [[ARG]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT:      [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT:      [[DIV:%.+]] = arith.divf [[ABSVALUE]], [[POW]] : vector<3xf64>
+// CHECK-NEXT:      [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : vector<3xf64> to vector<3xi32>
+// CHECK-NEXT:      [[REM:%.+]] = arith.remf [[ABSVALUE]], [[POW]] : vector<3xf64>
+// CHECK-NEXT:      [[LOWHALF:%.+]] = arith.fptoui [[REM]] : vector<3xf64> to vector<3xi32>
+// CHECK-NEXT:      [[HIGHHALFX1:%.+]] = vector.shape_cast [[HIGHHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-NEXT:      [[LOWHALFX1:%.+]] = vector.shape_cast [[LOWHALF]] : vector<3xi32> to vector<3x1xi32>
+// CHECK:           vector.insert_strided_slice [[LOWHALFX1]], %{{.+}} {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT:      [[FPTOUIRESVEC:%.+]] = vector.insert_strided_slice [[HIGHHALFX1]]
+// CHECK:           [[ZEROCSTINTHALF:%.+]] = vector.extract_strided_slice [[ZEROCSTINT]]
+// CHECK-SAME:      {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK:           [[SUB:%.+]] = arith.subi [[ZEROCSTINTHALF]], %{{.+}} : vector<3x1xi32>
+// CHECK-NEXT:      arith.cmpi ult, [[ZEROCSTINTHALF]], %{{.+}} : vector<3x1xi32>
+// CHECK-NEXT:      arith.extui
+// CHECK-NEXT:      arith.subi
+// CHECK-NEXT:      arith.subi
+// CHECK:           vector.insert_strided_slice [[SUB]]
+// CHECK-NEXT:      [[SUBVEC:%.+]] = vector.insert_strided_slice
+// CHECK:           [[SUB:%.+]] = vector.extract_strided_slice [[SUBVEC]]
+// CHECK-SAME:      {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK:           [[LOWRES:%.+]] = vector.extract_strided_slice [[FPTOUIRESVEC]]
+// CHECK-SAME:      {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK:           [[ISNEGATIVEX1:%.+]] = vector.shape_cast [[ISNEGATIVE]] : vector<3xi1> to vector<3x1xi1>
+// CHECK:           [[ABSRES:%.+]] = arith.select [[ISNEGATIVEX1]], [[SUB]], [[LOWRES]] : vector<3x1xi1>, vector<3x1xi32>
+// CHECK-NEXT:      arith.select [[ISNEGATIVEX1]]
+// CHECK:           vector.insert_strided_slice [[ABSRES]]
+// CHECK-NEXT:      [[ABSRESVEC:%.+]] = vector.insert_strided_slice
+// CHECK-NEXT:      return [[ABSRESVEC]] : vector<3x2xi32>
+func.func @fptosi_i64_f64_vector(%a : vector<3xf64>) -> vector<3xi64> {
+    %r = arith.fptosi %a : vector<3xf64> to vector<3xi64>
+    return %r : vector<3xi64>
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir
new file mode 100644
index 0000000000000..d93b834c8f919
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptosi-i64.mlir
@@ -0,0 +1,89 @@
+// Check that the wide integer `arith.fptosi` emulation produces the same result as wide
+// `arith.fptosi`. Emulate i64 ops with i32 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=32" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Only the ops in this function will be emulated using i32 types.
+func.func @emulate_fptosi(%arg: f64) -> i64 {
+  %res = arith.fptosi %arg : f64 to i64
+  return %res : i64
+}
+
+func.func @check_fptosi(%arg : f64) -> () {
+  %res = func.call @emulate_fptosi(%arg) : (f64) -> (i64)
+  vector.print %res : i64
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0.0 : f64
+  %cst_nzero = arith.constant 0x8000000000000000 : f64
+  %cst1 = arith.constant 1.0 : f64
+  %cst_n1 = arith.constant -1.0 : f64
+  %cst_n1_5 = arith.constant -1.5 : f64
+
+  %cstpow20 = arith.constant 1048576.0 : f64
+  %cstnpow20 = arith.constant -1048576.0 : f64
+  
+  %cst_i32_max = arith.constant 4294967295.0 : f64
+  %cst_i32_min = arith.constant -4294967296.0 : f64
+  %cst_i32_overflow = arith.constant 4294967296.0 : f64
+  %cst_i32_noverflow = arith.constant -4294967297.0 : f64
+
+  %cstpow40 = arith.constant 1099511627776.0 : f64
+  %cstnpow40 = arith.constant -1099511627776.0 : f64
+  %cst_pow40ppow20 = arith.constant 1099512676352.0 : f64
+  %cst_npow40ppow20 = arith.constant -1099512676352.0 : f64
+
+  %cst_max = arith.constant 9007199254740992.0 : f64
+  %cst_min = arith.constant -9007199254740992.0 : f64
+  
+  // CHECK:         0
+  func.call @check_fptosi(%cst0) : (f64) -> ()
+  // CHECK-NEXT:    0
+  func.call @check_fptosi(%cst_nzero) : (f64) -> ()
+  // CHECK-NEXT:    1
+  func.call @check_fptosi(%cst1) : (f64) -> ()
+  // CHECK-NEXT:    -1
+  func.call @check_fptosi(%cst_n1) : (f64) -> ()
+  // CHECK-NEXT:    -1
+  func.call @check_fptosi(%cst_n1_5) : (f64) -> ()
+  // CHECK-NEXT:    1048576
+  func.call @check_fptosi(%cstpow20) : (f64) -> ()
+  // CHECK-NEXT:    -1048576
+  func.call @check_fptosi(%cstnpow20) : (f64) -> ()
+  // CHECK-NEXT:    4294967295
+  func.call @check_fptosi(%cst_i32_max) : (f64) -> ()
+  // CHECK-NEXT:    -4294967296
+  func.call @check_fptosi(%cst_i32_min) : (f64) -> ()
+  // CHECK-NEXT:    4294967296
+  func.call @check_fptosi(%cst_i32_overflow) : (f64) -> ()
+  // CHECK-NEXT:    -4294967297
+  func.call @check_fptosi(%cst_i32_noverflow) : (f64) -> ()
+  // CHECK-NEXT:    1099511627776
+  func.call @check_fptosi(%cstpow40) : (f64) -> ()
+  // CHECK-NEXT:    -1099511627776
+  func.call @check_fptosi(%cstnpow40) : (f64) -> ()
+  // CHECK-NEXT:    1099512676352
+  func.call @check_fptosi(%cst_pow40ppow20) : (f64) -> ()
+  // CHECK-NEXT:    -1099512676352
+  func.call @check_fptosi(%cst_npow40ppow20) : (f64) -> ()
+  // CHECK-NEXT:    9007199254740992
+  func.call @check_fptosi(%cst_max) : (f64) -> ()
+  // CHECK-NEXT:    -9007199254740992
+  func.call @check_fptosi(%cst_min) : (f64) -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir
new file mode 100644
index 0000000000000..81283ee9fdfd8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-fptoui-i64.mlir
@@ -0,0 +1,64 @@
+// Check that the wide integer `arith.fptoui` emulation produces the same result as wide
+// `arith.fptoui`. Emulate i64 ops with i32 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=32" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Only the ops in this function will be emulated using i32 types.
+func.func @emulate_fptoui(%arg: f64) -> i64 {
+  %res = arith.fptoui %arg : f64 to i64
+  return %res : i64
+}
+
+func.func @check_fptoui(%arg : f64) -> () {
+  %res = func.call @emulate_fptoui(%arg) : (f64) -> (i64)
+  vector.print %res : i64
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0.0 : f64
+  %cst1 = arith.constant 1.0 : f64
+  %cst1_5 = arith.constant 1.5 : f64
+
+  %cstpow20 = arith.constant 1048576.0 : f64
+  %cst_i32_max = arith.constant 4294967295.0 : f64
+  %cst_i32_overflow = arith.constant 4294967296.0 : f64
+
+  %cstpow40 = arith.constant 1099511627776.0 : f64
+  %cst_pow40ppow20 = arith.constant 1099512676352.0 : f64
+
+  %cst_nzero = arith.constant 0x8000000000000000 : f64
+  
+  // CHECK:         0
+  func.call @check_fptoui(%cst0) : (f64) -> ()
+  // CHECK-NEXT:    1
+  func.call @check_fptoui(%cst1) : (f64) -> ()
+  // CHECK-NEXT:    1
+  func.call @check_fptoui(%cst1_5) : (f64) -> ()
+  // CHECK-NEXT:    1048576
+  func.call @check_fptoui(%cstpow20) : (f64) -> ()
+  // CHECK-NEXT:    4294967295
+  func.call @check_fptoui(%cst_i32_max) : (f64) -> ()
+  // CHECK-NEXT:    4294967296
+  func.call @check_fptoui(%cst_i32_overflow) : (f64) -> ()
+  // CHECK-NEXT:    1099511627776
+  func.call @check_fptoui(%cstpow40) : (f64) -> ()
+  // CHECK-NEXT:    1099512676352
+  func.call @check_fptoui(%cst_pow40ppow20) : (f64) -> ()
+  // CHECK-NEXT:    0
+  func.call @check_fptoui(%cst_nzero) : (f64) -> ()
+
+  return
+}



More information about the Mlir-commits mailing list