[Mlir-commits] [mlir] [mlir][arith] add wide integer emulation support for subi (PR #133248)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 27 08:48:45 PDT 2025
https://github.com/egebeysel updated https://github.com/llvm/llvm-project/pull/133248
>From aba2153312dc47f1aceb904d5162c3f5152590d7 Mon Sep 17 00:00:00 2001
From: Ege Beysel <beysel at roofline.ai>
Date: Thu, 27 Mar 2025 14:00:10 +0100
Subject: [PATCH] [mlir][arith] add wide integer emulation support for subi &
update sitofp to use it
Signed-off-by: Ege Beysel <beysel at roofline.ai>
---
.../Arith/Transforms/EmulateWideInt.cpp | 56 ++++++++++---
mlir/test/Dialect/Arith/emulate-wide-int.mlir | 55 +++++++++++--
.../CPU/test-wide-int-emulation-subi-i32.mlir | 81 +++++++++++++++++++
3 files changed, 173 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..3226b5d99114a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -866,6 +866,46 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertSubI
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for `arith.subi` whose result type is converted by the
+/// type converter to a vector type. The operands are split into the two
+/// halves along the last dimension, subtracted half by half with a borrow
+/// propagated from the first (low) half into the second (high) half, and the
+/// two partial results are reassembled into the converted vector type.
+struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
+    // Bail out when the result type has no vector-typed conversion. The
+    // explicit `{0}` index matches the message used by the other patterns.
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    Type newElemTy = reduceInnermostDim(newTy);
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - BORROW] where
+    // BORROW is 1 or 0.
+    Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
+    // The low half borrows from the high half iff lhsElem0 < rhsElem0 when
+    // both are compared as unsigned values.
+    Value borrowBit = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
+    // Zero-extend the comparison result so it can be subtracted from the
+    // high half.
+    Value borrow = rewriter.create<arith::ExtUIOp>(loc, newElemTy, borrowBit);
+
+    // Subtract the borrow first and the RHS high half second; tests match the
+    // emitted ops in this order.
+    Value highMinusBorrow =
+        rewriter.create<arith::SubIOp>(loc, lhsElem1, borrow);
+    Value high =
+        rewriter.create<arith::SubIOp>(loc, highMinusBorrow, rhsElem1);
+
+    Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertSIToFP
//===----------------------------------------------------------------------===//
@@ -885,22 +925,16 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
return rewriter.notifyMatchFailure(
loc, llvm::formatv("unsupported type: {0}", oldTy));
- unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
- Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
- Value allOnesCst = createScalarOrSplatConstant(
- rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
// To avoid operating on very large unsigned numbers, perform the
// conversion on the absolute value. Then, decide whether to negate the
- // result or not based on that sign bit. We assume two's complement and
- // implement negation by flipping all bits and adding 1.
- // Note that this relies on the the other conversion patterns to legalize
- // created ops and narrow the bit widths.
+ // result or not based on that sign bit. We implement negation by
+ // subtracting from zero. Note that this relies on the other conversion
+ // patterns to legalize created ops and narrow the bit widths.
Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
in, zeroCst);
- Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
- Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+ Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
@@ -1139,7 +1173,7 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
- ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
+ ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
// Bitwise binary ops.
ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ed08779c10266..5603d8e5064cb 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -130,6 +130,44 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
return %x : vector<4xi64>
}
+// Scalar i64 subtraction. After emulation both operands and the result are
+// vector<2xi32> values. The expected lowering subtracts the elements at index
+// 0, derives a borrow with an unsigned less-than compare of those elements,
+// zero-extends it to i32, and subtracts first the borrow and then the RHS
+// element at index 1 from the LHS element at index 1.
+// CHECK-LABEL: func @subi_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : i32 from vector<2xi32>
+// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : i1 to i32
+// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : i32
+// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[SUB_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUB_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @subi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+ %x = arith.subi %a, %b : i64
+ return %x : i64
+}
+
+// Vector variant of the subtraction test: vector<4xi64> operands become
+// vector<4x2xi32>. The two halves are taken as vector<4x1xi32> strided slices
+// at offsets [0, 0] and [0, 1], the same subtract / unsigned-compare /
+// zero-extend / subtract / subtract sequence is applied to the slices, and
+// the results are inserted back as strided slices at the same offsets.
+// CHECK-LABEL: func @subi_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : vector<4x1xi32>
+// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK: [[INS0:%.+]] = vector.insert_strided_slice [[SUB_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[SUB_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<4x2xi32>
+func.func @subi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+ %x = arith.subi %a, %b : vector<4xi64>
+ return %x : vector<4xi64>
+}
+
// CHECK-LABEL: func.func @cmpi_eq_scalar
// CHECK-SAME: ([[LHS:%.+]]: vector<2xi32>, [[RHS:%.+]]: vector<2xi32>)
// CHECK-NEXT: [[LHSLOW:%.+]] = vector.extract [[LHS]][0] : i32 from vector<2xi32>
@@ -967,11 +1005,12 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
// CHECK-LABEL: func @sitofp_i64_f64
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64
-// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<2xi32>
-// CHECK: [[ONES1:%.+]] = vector.extract [[VONES]][0] : i32 from vector<2xi32>
-// CHECK-NEXT: [[ONES2:%.+]] = vector.extract [[VONES]][1] : i32 from vector<2xi32>
-// CHECK: arith.xori {{%.+}}, [[ONES1]] : i32
-// CHECK-NEXT: arith.xori {{%.+}}, [[ONES2]] : i32
+// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK: [[ZERO1:%.+]] = vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[ZERO2:%.+]] = vector.extract [[VZERO]][1] : i32 from vector<2xi32>
+// CHECK: arith.subi [[ZERO1]], {{%.+}} : i32
+// CHECK: arith.subi [[ZERO2]], {{%.+}} : i32
// CHECK: [[CST0:%.+]] = arith.constant 0 : i32
// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : i32 to f64
@@ -990,9 +1029,9 @@ func.func @sitofp_i64_f64(%a : i64) -> f64 {
// CHECK-LABEL: func @sitofp_i64_f64_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
-// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<3x2xi32>
-// CHECK: arith.xori
-// CHECK-NEXT: arith.xori
+// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK: arith.subi
+// CHECK: arith.subi
// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
new file mode 100644
index 0000000000000..7f0e8fd111028
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
@@ -0,0 +1,81 @@
+// Ops in this function will be emulated using i16 types.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Subtraction under test: returns `%arg - %arg0` as an i32.
+func.func @emulate_subi(%arg: i32, %arg0: i32) -> i32 {
+  %diff = arith.subi %arg, %arg0 : i32
+  return %diff : i32
+}
+
+// Calls @emulate_subi on the two operands and prints the i32 result so that
+// FileCheck can match it.
+func.func @check_subi(%arg : i32, %arg0 : i32) -> () {
+  %diff = func.call @emulate_subi(%arg, %arg0) : (i32, i32) -> (i32)
+  vector.print %diff : i32
+  return
+}
+
+// Test driver: runs @check_subi on each operand pair in both orders. The
+// printed results are matched, in order, by the CHECK lines below.
+func.func @entry() {
+  %lhs1 = arith.constant 1 : i32
+  %rhs1 = arith.constant 2 : i32
+
+  // CHECK: -1
+  func.call @check_subi(%lhs1, %rhs1) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%rhs1, %lhs1) : (i32, i32) -> ()
+
+  %lhs2 = arith.constant 1 : i32
+  %rhs2 = arith.constant -2 : i32
+
+  // CHECK-NEXT: 3
+  func.call @check_subi(%lhs2, %rhs2) : (i32, i32) -> ()
+  // CHECK-NEXT: -3
+  func.call @check_subi(%rhs2, %lhs2) : (i32, i32) -> ()
+
+  %lhs3 = arith.constant -1 : i32
+  %rhs3 = arith.constant -2 : i32
+
+  // CHECK-NEXT: 1
+  func.call @check_subi(%lhs3, %rhs3) : (i32, i32) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_subi(%rhs3, %lhs3) : (i32, i32) -> ()
+
+  // Borrow out of the lower 16-bit half (first call) and wrap-around in the
+  // upper 16-bit half (second call).
+  %lhs4 = arith.constant 131074 : i32
+  %rhs4 = arith.constant 3 : i32
+
+  // CHECK-NEXT: 131071
+  func.call @check_subi(%lhs4, %rhs4) : (i32, i32) -> ()
+  // CHECK-NEXT: -131071
+  func.call @check_subi(%rhs4, %lhs4) : (i32, i32) -> ()
+
+  // Overflow in both 16-bit halves (first call): the lower half borrows and
+  // the upper half wraps around.
+  %lhs5 = arith.constant 16385027 : i32
+  %rhs5 = arith.constant 16450564 : i32
+
+  // CHECK-NEXT: -65537
+  func.call @check_subi(%lhs5, %rhs5) : (i32, i32) -> ()
+  // CHECK-NEXT: 65537
+  func.call @check_subi(%rhs5, %lhs5) : (i32, i32) -> ()
+
+  // Max/Min signed integers: both differences wrap around in 32 bits.
+  %sint_max = arith.constant 2147483647 : i32
+  %sint_min = arith.constant -2147483648 : i32
+
+  // CHECK-NEXT: -1
+  func.call @check_subi(%sint_max, %sint_min) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%sint_min, %sint_max) : (i32, i32) -> ()
+
+  return
+}
More information about the Mlir-commits
mailing list