[llvm] [InstCombine] Fold `sext(trunc nsw)` and `zext(trunc nuw)` (PR #88609)

via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 13 02:32:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Monad (YanWQ-monad)

<details>
<summary>Changes</summary>

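Fold each extend-of-truncate pair into a single cast of `X` to `Z`:
- `sext (trunc nsw X to Y) to Z` becomes `sext X to Z` when `Z` is wider than `X`, or `trunc nsw X to Z` when `Z` is narrower than `X`; and
- `zext (trunc nuw X to Y) to Z` becomes `zext nneg X to Z` when `Z` is wider than `X`, or `trunc nuw X to Z` when `Z` is narrower than `X`.

Both folds are sound because `trunc nsw` (resp. `trunc nuw`) guarantees that `X` is exactly representable as a signed (resp. unsigned) `Y`. Extending the truncated value therefore reproduces the value of `X` itself. For the `zext` case, the `nneg` flag is justified because `X`'s sign bit lies among the truncated-away bits, which `nuw` requires to be zero.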

Alive2 proofs:
- `sext`: https://alive2.llvm.org/ce/z/cqsk5t
- `zext`: https://alive2.llvm.org/ce/z/kdtEWb

---
Full diff: https://github.com/llvm/llvm-project/pull/88609.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+21-1) 
- (modified) llvm/test/Transforms/InstCombine/sext.ll (+41) 
- (modified) llvm/test/Transforms/InstCombine/zext.ll (+41) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 437e9b92c7032f..91c149305bb76c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1188,9 +1188,20 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
   if (auto *CSrc = dyn_cast<TruncInst>(Src)) {   // A->B->C cast
     // TODO: Subsume this into EvaluateInDifferentType.
 
+    Value *A = CSrc->getOperand(0);
+    // If TRUNC has nuw flag, then convert directly to final type.
+    if (CSrc->hasNoUnsignedWrap()) {
+      CastInst *I =
+          CastInst::CreateIntegerCast(A, DestTy, /* isSigned */ false);
+      if (auto *ZExt = dyn_cast<ZExtInst>(I))
+        ZExt->setNonNeg();
+      if (auto *Trunc = dyn_cast<TruncInst>(I))
+        Trunc->setHasNoUnsignedWrap(true);
+      return I;
+    }
+
     // Get the sizes of the types involved.  We know that the intermediate type
     // will be smaller than A or C, but don't know the relation between A and C.
-    Value *A = CSrc->getOperand(0);
     unsigned SrcSize = A->getType()->getScalarSizeInBits();
     unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
     unsigned DstSize = DestTy->getScalarSizeInBits();
@@ -1467,6 +1478,15 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
     if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
       return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
 
+    // If trunc has nsw flag, then convert directly to final type.
+    auto *CSrc = static_cast<TruncInst *>(Src);
+    if (CSrc->hasNoSignedWrap()) {
+      CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+      if (auto *Trunc = dyn_cast<TruncInst>(I))
+        Trunc->setHasNoSignedWrap(true);
+      return I;
+    }
+
     // If input is a trunc from the destination type, then convert into shifts.
     if (Src->hasOneUse() && X->getType() == DestTy) {
       // sext (trunc X) --> ashr (shl X, C), C
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index e3b6058ce7f806..9eae03470a4693 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -423,3 +423,44 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
   %s = sext i8 %a to i64
   ret i64 %s
 }
+
+define i32 @sext_trunc_nsw(i16 %x) {
+; CHECK-LABEL: @sext_trunc_nsw(
+; CHECK-NEXT:    [[E:%.*]] = sext i16 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[E]]
+;
+  %c = trunc nsw i16 %x to i8
+  %e = sext i8 %c to i32
+  ret i32 %e
+}
+
+define i16 @sext_trunc_nsw_2(i32 %x) {
+; CHECK-LABEL: @sext_trunc_nsw_2(
+; CHECK-NEXT:    [[E:%.*]] = trunc nsw i32 [[X:%.*]] to i16
+; CHECK-NEXT:    ret i16 [[E]]
+;
+  %c = trunc nsw i32 %x to i8
+  %e = sext i8 %c to i16
+  ret i16 %e
+}
+
+define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @sext_trunc_nsw_vec(
+; CHECK-NEXT:    [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[E]]
+;
+  %c = trunc nsw <2 x i16> %x to <2 x i8>
+  %e = sext <2 x i8> %c to <2 x i32>
+  ret <2 x i32> %e
+}
+
+define i32 @sext_trunc(i16 %x) {
+; CHECK-LABEL: @sext_trunc(
+; CHECK-NEXT:    [[C:%.*]] = trunc i16 [[X:%.*]] to i8
+; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT:    ret i32 [[E]]
+;
+  %c = trunc i16 %x to i8
+  %e = sext i8 %c to i32
+  ret i32 %e
+}
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 88cd9c70af40d8..16e7ef143cef9e 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -867,3 +867,44 @@ entry:
   %res = zext nneg i2 %x to i32
   ret i32 %res
 }
+
+define i32 @zext_trunc_nuw(i16 %x) {
+; CHECK-LABEL: @zext_trunc_nuw(
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[E1]]
+;
+  %c = trunc nuw i16 %x to i8
+  %e = zext i8 %c to i32
+  ret i32 %e
+}
+
+define i16 @zext_trunc_nuw_2(i32 %x) {
+; CHECK-LABEL: @zext_trunc_nuw_2(
+; CHECK-NEXT:    [[E:%.*]] = trunc nuw i32 [[X:%.*]] to i16
+; CHECK-NEXT:    ret i16 [[E]]
+;
+  %c = trunc nuw i32 %x to i8
+  %e = zext i8 %c to i16
+  ret i16 %e
+}
+
+define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @zext_trunc_nuw_vec(
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[E1]]
+;
+  %c = trunc nuw <2 x i16> %x to <2 x i8>
+  %e = zext <2 x i8> %c to <2 x i32>
+  ret <2 x i32> %e
+}
+
+define i32 @zext_trunc(i16 %x) {
+; CHECK-LABEL: @zext_trunc(
+; CHECK-NEXT:    [[E:%.*]] = and i16 [[X:%.*]], 255
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[E]] to i32
+; CHECK-NEXT:    ret i32 [[E1]]
+;
+  %c = trunc i16 %x to i8
+  %e = zext i8 %c to i32
+  ret i32 %e
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88609


More information about the llvm-commits mailing list