[Mlir-commits] [mlir] [mlir][arith] add wide integer emulation support for subi (PR #133248)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 27 06:08:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-arith

Author: None (egebeysel)

<details>
<summary>Changes</summary>

Adds wide integer emulation support for the `arith.subi` op. Like the other emulation patterns, `(i2N, i2N) -> (i2N)` ops are emulated as `(vector<2xiN>, vector<2xiN>) -> (vector<2xiN>)`.

The emulation uses the following scheme:

```
resLow = lhsLow - rhsLow;             // carry (borrow) = 1 if rhsLow > lhsLow (unsigned), else 0
resHigh = lhsHigh - carry - rhsHigh;
```

---
Full diff: https://github.com/llvm/llvm-project/pull/133248.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+42-1) 
- (modified) mlir/test/Dialect/Arith/emulate-wide-int.mlir (+38) 
- (added) mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir (+71) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..1d36b751083b7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -866,6 +867,46 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertSubI
+//===----------------------------------------------------------------------===//
+
+struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    Type newElemTy = reduceInnermostDim(newTy);
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // Emulates LHS - RHS as [LHS0 - RHS0, LHS1 - CARRY - RHS1], where CARRY
+    // is the borrow (0 or 1) out of the low half; each subi wraps modulo 2^N.
+    Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
+    // A borrow occurs iff lhsElem0 < rhsElem0, compared as unsigned (ult).
+    Value carry0 = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
+    Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
+
+    Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
+    Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
+
+    Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertSIToFP
 //===----------------------------------------------------------------------===//
@@ -1139,7 +1180,7 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
       ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
       ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
-      ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
+      ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
       // Bitwise binary ops.
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ed08779c10266..fb2d73369b0a3 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -130,6 +130,44 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
     return %x : vector<4xi64>
 }
 
+// CHECK-LABEL: func @subi_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : i32 from vector<2xi32>
+// CHECK-NEXT:    [[SUB_L:%.+]]  = arith.subi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[ULT:%.+]]    = arith.cmpi ult, [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[CARRY:%.+]]  = arith.extui [[ULT]] : i1 to i32
+// CHECK-NEXT:    [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : i32
+// CHECK-NEXT:    [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[SUB_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[SUB_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @subi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+    %x = arith.subi %a, %b : i64  // Lowered on vector<2xi32> halves + borrow.
+    return %x : i64
+}
+
+// CHECK-LABEL: func @subi_vector_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT:    [[SUB_L:%.+]]  = arith.subi [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT:    [[ULT:%.+]]    = arith.cmpi ult, [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT:    [[CARRY:%.+]]  = arith.extui [[ULT]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT:    [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : vector<4x1xi32>
+// CHECK-NEXT:    [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK:         [[INS0:%.+]]   = vector.insert_strided_slice [[SUB_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert_strided_slice [[SUB_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<4x2xi32>
+func.func @subi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+    %x = arith.subi %a, %b : vector<4xi64>  // Lowered on vector<4x2xi32>.
+    return %x : vector<4xi64>
+}
+
 // CHECK-LABEL: func.func @cmpi_eq_scalar
 // CHECK-SAME:    ([[LHS:%.+]]: vector<2xi32>, [[RHS:%.+]]: vector<2xi32>)
 // CHECK-NEXT:    [[LHSLOW:%.+]]  = vector.extract [[LHS]][0] : i32 from vector<2xi32>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
new file mode 100644
index 0000000000000..a82c8f0ea1c09
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
@@ -0,0 +1,71 @@
+// In the second RUN line, i32 ops in this file are emulated using i16 types.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+func.func @emulate_subi(%arg: i32, %arg0: i32) -> i32 {
+  %res = arith.subi %arg, %arg0 : i32  // Split into i16 halves in RUN 2.
+  return %res : i32
+}
+
+func.func @check_subi(%arg : i32, %arg0 : i32) -> () {
+  %res = func.call @emulate_subi(%arg, %arg0) : (i32, i32) -> (i32)
+  vector.print %res : i32  // Matched by the FileCheck lines in @entry.
+  return
+}
+
+func.func @entry() {
+  %lhs1 = arith.constant 1 : i32
+  %rhs1 = arith.constant 2 : i32
+
+  // CHECK:       -1
+  func.call @check_subi(%lhs1, %rhs1) : (i32, i32) -> ()
+  // CHECK-NEXT:  1
+  func.call @check_subi(%rhs1, %lhs1) : (i32, i32) -> ()
+  
+  %lhs2 = arith.constant 1 : i32
+  %rhs2 = arith.constant -2 : i32
+
+  // CHECK-NEXT:  3
+  func.call @check_subi(%lhs2, %rhs2) : (i32, i32) -> ()
+  // CHECK-NEXT:  -3
+  func.call @check_subi(%rhs2, %lhs2) : (i32, i32) -> ()
+  
+  %lhs3 = arith.constant -1 : i32
+  %rhs3 = arith.constant -2 : i32
+
+  // CHECK-NEXT:  1
+  func.call @check_subi(%lhs3, %rhs3) : (i32, i32) -> ()
+  // CHECK-NEXT:  -1
+  func.call @check_subi(%rhs3, %lhs3) : (i32, i32) -> ()
+  
+  // i16 halves: lhs4 - rhs4 borrows from the high half; rhs4 - lhs4 wraps it.
+  %lhs4 = arith.constant 131074 : i32
+  %rhs4 = arith.constant 3 : i32
+
+  // CHECK-NEXT:  131071
+  func.call @check_subi(%lhs4, %rhs4) : (i32, i32) -> ()
+  // CHECK-NEXT:  -131071
+  func.call @check_subi(%rhs4, %lhs4) : (i32, i32) -> ()
+
+  // lhs5 - rhs5 borrows from the high half, which then wraps below zero.
+  %lhs5 = arith.constant 16385027 : i32 
+  %rhs5 = arith.constant 16450564 : i32
+
+  // CHECK-NEXT:  -65537
+  func.call @check_subi(%lhs5, %rhs5) : (i32, i32) -> ()
+  // CHECK-NEXT:  65537
+  func.call @check_subi(%rhs5, %lhs5) : (i32, i32) -> ()
+
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/133248


More information about the Mlir-commits mailing list