[llvm] [InstCombine] Fold extended add/sub of the same type (PR #185259)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 8 00:09:25 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: suoyuan (suoyuan666)

<details>
<summary>Changes</summary>

Fixes #185244

```c
uint16_t sum = (uint16_t)a + (uint16_t)b;
return (sum > 255) ? 255 : (uint8_t)sum;
```

With `a` and `b` of type `uint8_t`, the code above compiles to the following LLVM IR:

```llvm
%3 = zext i8 %0 to i16
%4 = zext i8 %1 to i16
%5 = add nuw nsw i16 %4, %3
%6 = tail call i16 @llvm.umin.i16(i16 %5, i16 255)
%7 = trunc nuw i16 %6 to i8
ret i8 %7
```

However, this whole widen/add/clamp/narrow sequence can be replaced by a single call to `llvm.uadd.sat` on the original `i8` operands, which makes subsequent optimizations easier.

I also added the analogous fold for subtraction (`trunc(smax(zext(a) - zext(b), 0))` becomes `llvm.usub.sat(a, b)`), although I am not sure what C code would compile into that LLVM IR.

```bash
$ cmake --build build/ -t check-llvm-transforms-instcombine
[0/1] Running lit suite /home/zuos/git_repo/llvm-project/llvm/test/Transforms/InstCombine

Testing Time: 4.54s

Total Discovered Tests: 1791
  Unsupported:  110 (6.14%)
  Passed     : 1681 (93.86%)

$ cmake --build build/ -t check-llvm
[0/1] Running the LLVM regression tests

Testing Time: 861.13s

Total Discovered Tests: 72369
  Skipped          :   409 (0.57%)
  Unsupported      : 30572 (42.24%)
  Passed           : 41320 (57.10%)
  Expectedly Failed:    68 (0.09%)
```

I'm new to LLVM, so there may be issues with naming or with my choice of APIs. Please point them out so I can fix them. Thanks! :)

---
Full diff: https://github.com/llvm/llvm-project/pull/185259.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+22) 
- (modified) llvm/test/Transforms/InstCombine/add.ll (+30) 
- (modified) llvm/test/Transforms/InstCombine/sub.ll (+30) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 2f3c9c6a083bd..ad8e19653c410 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1081,6 +1081,28 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
 
   Value *A, *B;
   Constant *C;
+  ConstantInt *CInt;
+
+  if (match(Src, m_UMin(m_Add(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))),
+                        m_ConstantInt(CInt)))) {
+    // trunc (umin(zext(a) + zext(b), MAX)) --> uadd.sat(a, b)
+    if (A->getType() == DestTy && B->getType() == DestTy &&
+        APInt::isSameValue(CInt->getValue(), APInt::getMaxValue(DestWidth).zext(
+                                                 CInt->getBitWidth()))) {
+      return replaceInstUsesWith(
+          Trunc, Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, A, B));
+    }
+  }
+
+  if (match(Src,
+            m_SMax(m_Sub(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))), m_Zero()))) {
+    // trunc(smax(zext(a) - zext(b), 0)) --> usub.sat(a, b)
+    if (A->getType() == DestTy && B->getType() == DestTy) {
+      return replaceInstUsesWith(
+          Trunc, Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, B));
+    }
+  }
+
   if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) {
     unsigned AWidth = A->getType()->getScalarSizeInBits();
     unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index aa68dfb540064..b887b136e5a1c 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -4532,6 +4532,36 @@ define <2 x i32> @ceil_div_vec_multi_use(<2 x i32> range(i32 0, 1000) %x) {
   ret <2 x i32> %r
 }
 
+define i8 @fold_to_uadd_sat_with_same_type(i8 %a, i8 %b) {
+; CHECK-LABEL: @fold_to_uadd_sat_with_same_type(
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %za = zext i8 %a to i16
+  %zb = zext i8 %b to i16
+  %zsum = add nuw nsw i16 %zb, %za
+  %cmp = call i16 @llvm.umin.i16(i16 %zsum, i16 255)
+  %r = trunc nuw i16 %cmp to i8
+  ret i8 %r
+}
+
+define i16 @cannot_fold_to_uadd_sat_with_different_type(i8 %a, i16 %b) {
+; CHECK-LABEL: @cannot_fold_to_uadd_sat_with_different_type(
+; CHECK-NEXT:    [[ZA:%.*]] = zext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[ZB:%.*]] = zext i16 [[B:%.*]] to i32
+; CHECK-NEXT:    [[ZSUM:%.*]] = add nuw nsw i32 [[ZB]], [[ZA]]
+; CHECK-NEXT:    [[CMP:%.*]] = call i32 @llvm.umin.i32(i32 [[ZSUM]], i32 65535)
+; CHECK-NEXT:    [[R:%.*]] = trunc nuw i32 [[CMP]] to i16
+; CHECK-NEXT:    ret i16 [[R]]
+;
+  %za = zext i8 %a to i32
+  %zb = zext i16 %b to i32
+  %zsum = add nuw nsw i32 %zb, %za
+  %cmp = call i32 @llvm.umin.i32(i32 %zsum, i32 65535)
+  %r = trunc nuw i32 %cmp to i16
+  ret i16 %r
+}
+
 declare void @use_i32(i32)
 declare void @use_vec(<2 x i32>)
 declare void @fake_func(i32)
diff --git a/llvm/test/Transforms/InstCombine/sub.ll b/llvm/test/Transforms/InstCombine/sub.ll
index 439b59946fac1..f7dc6e41b30c5 100644
--- a/llvm/test/Transforms/InstCombine/sub.ll
+++ b/llvm/test/Transforms/InstCombine/sub.ll
@@ -2863,3 +2863,33 @@ entry:
   %and = and i32 %sub, 127
   ret i32 %and
 }
+
+define i8 @fold_to_usub_sat_with_same_type(i8 %a, i8 %b) {
+; CHECK-LABEL: @fold_to_usub_sat_with_same_type(
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %za = zext i8 %a to i16
+  %zb = zext i8 %b to i16
+  %zsub = sub nsw i16 %za, %zb
+  %cmp = call i16 @llvm.smax.i16(i16 %zsub, i16 0)
+  %r = trunc i16 %cmp to i8
+  ret i8 %r
+}
+
+define i8 @negative_usub_different_widths(i8 %a, i16 %b) {
+; CHECK-LABEL: @negative_usub_different_widths(
+; CHECK-NEXT:    [[ZA:%.*]] = zext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[ZB:%.*]] = zext i16 [[B:%.*]] to i32
+; CHECK-NEXT:    [[ZSUB:%.*]] = sub nsw i32 [[ZA]], [[ZB]]
+; CHECK-NEXT:    [[CMP:%.*]] = call i32 @llvm.smax.i32(i32 [[ZSUB]], i32 0)
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[CMP]] to i8
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %za = zext i8 %a to i32
+  %zb = zext i16 %b to i32
+  %zsub = sub nsw i32 %za, %zb
+  %cmp = call i32 @llvm.smax.i32(i32 %zsub, i32 0)
+  %r = trunc i32 %cmp to i8
+  ret i8 %r
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/185259


More information about the llvm-commits mailing list