[llvm] [InstCombine] Fold `sext(trunc nsw)` and `zext(trunc nuw)` (PR #88609)

via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 13 07:17:38 PDT 2024


https://github.com/YanWQ-monad updated https://github.com/llvm/llvm-project/pull/88609

>From 17d3737524754878672ee3331d5d50bbf1793154 Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 17:17:06 +0800
Subject: [PATCH 1/5] pre-commit: add tests

---
 llvm/test/Transforms/InstCombine/sext.ll | 44 ++++++++++++++++++++++++
 llvm/test/Transforms/InstCombine/zext.ll | 44 ++++++++++++++++++++++++
 2 files changed, 88 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index e3b6058ce7f806..bf20de85152909 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -423,3 +423,47 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
   %s = sext i8 %a to i64
   ret i64 %s
 }
+
+define i32 @sext_trunc_nsw(i16 %x) {
+; CHECK-LABEL: @sext_trunc_nsw(
+; CHECK-NEXT:    [[C:%.*]] = trunc nsw i16 [[X:%.*]] to i8
+; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT:    ret i32 [[E]]
+;
+  %c = trunc nsw i16 %x to i8
+  %e = sext i8 %c to i32
+  ret i32 %e
+}
+
+define i16 @sext_trunc_nsw_2(i32 %x) {
+; CHECK-LABEL: @sext_trunc_nsw_2(
+; CHECK-NEXT:    [[C:%.*]] = trunc nsw i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i16
+; CHECK-NEXT:    ret i16 [[E]]
+;
+  %c = trunc nsw i32 %x to i8
+  %e = sext i8 %c to i16
+  ret i16 %e
+}
+
+define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @sext_trunc_nsw_vec(
+; CHECK-NEXT:    [[C:%.*]] = trunc nsw <2 x i16> [[X:%.*]] to <2 x i8>
+; CHECK-NEXT:    [[E:%.*]] = sext <2 x i8> [[C]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[E]]
+;
+  %c = trunc nsw <2 x i16> %x to <2 x i8>
+  %e = sext <2 x i8> %c to <2 x i32>
+  ret <2 x i32> %e
+}
+
+define i32 @sext_trunc(i16 %x) {
+; CHECK-LABEL: @sext_trunc(
+; CHECK-NEXT:    [[C:%.*]] = trunc i16 [[X:%.*]] to i8
+; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT:    ret i32 [[E]]
+;
+  %c = trunc i16 %x to i8
+  %e = sext i8 %c to i32
+  ret i32 %e
+}
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 88cd9c70af40d8..abaec98f777af7 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -867,3 +867,47 @@ entry:
   %res = zext nneg i2 %x to i32
   ret i32 %res
 }
+
+define i32 @zext_trunc_nuw(i16 %x) {
+; CHECK-LABEL: @zext_trunc_nuw(
+; CHECK-NEXT:    [[X:%.*]] = and i16 [[X1:%.*]], 255
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[X]] to i32
+; CHECK-NEXT:    ret i32 [[E1]]
+;
+  %c = trunc nuw i16 %x to i8
+  %e = zext i8 %c to i32
+  ret i32 %e
+}
+
+define i16 @zext_trunc_nuw_2(i32 %x) {
+; CHECK-LABEL: @zext_trunc_nuw_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
+; CHECK-NEXT:    [[E:%.*]] = and i16 [[TMP1]], 255
+; CHECK-NEXT:    ret i16 [[E]]
+;
+  %c = trunc nuw i32 %x to i8
+  %e = zext i8 %c to i16
+  ret i16 %e
+}
+
+define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @zext_trunc_nuw_vec(
+; CHECK-NEXT:    [[X:%.*]] = and <2 x i16> [[X1:%.*]], <i16 255, i16 255>
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg <2 x i16> [[X]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[E1]]
+;
+  %c = trunc nuw <2 x i16> %x to <2 x i8>
+  %e = zext <2 x i8> %c to <2 x i32>
+  ret <2 x i32> %e
+}
+
+define i32 @zext_trunc(i16 %x) {
+; CHECK-LABEL: @zext_trunc(
+; CHECK-NEXT:    [[E:%.*]] = and i16 [[X:%.*]], 255
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[E]] to i32
+; CHECK-NEXT:    ret i32 [[E1]]
+;
+  %c = trunc i16 %x to i8
+  %e = zext i8 %c to i32
+  ret i32 %e
+}

>From 42a0818b897b69c2eefe5864c20922f22040a486 Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 17:21:03 +0800
Subject: [PATCH 2/5] InstCombine: fold `sext(trunc nsw)` and `zext(trunc nuw)`

---
 .../InstCombine/InstCombineCasts.cpp          | 22 ++++++++++++++++++-
 llvm/test/Transforms/InstCombine/sext.ll      |  9 +++-----
 llvm/test/Transforms/InstCombine/zext.ll      |  9 +++-----
 3 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 437e9b92c7032f..91c149305bb76c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1188,9 +1188,20 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
   if (auto *CSrc = dyn_cast<TruncInst>(Src)) {   // A->B->C cast
     // TODO: Subsume this into EvaluateInDifferentType.
 
+    Value *A = CSrc->getOperand(0);
+    // If TRUNC has nuw flag, then convert directly to final type.
+    if (CSrc->hasNoUnsignedWrap()) {
+      CastInst *I =
+          CastInst::CreateIntegerCast(A, DestTy, /* isSigned */ false);
+      if (auto *ZExt = dyn_cast<ZExtInst>(I))
+        ZExt->setNonNeg();
+      if (auto *Trunc = dyn_cast<TruncInst>(I))
+        Trunc->setHasNoUnsignedWrap(true);
+      return I;
+    }
+
     // Get the sizes of the types involved.  We know that the intermediate type
     // will be smaller than A or C, but don't know the relation between A and C.
-    Value *A = CSrc->getOperand(0);
     unsigned SrcSize = A->getType()->getScalarSizeInBits();
     unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
     unsigned DstSize = DestTy->getScalarSizeInBits();
@@ -1467,6 +1478,15 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
     if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
       return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
 
+    // If trunc has nsw flag, then convert directly to final type.
+    auto *CSrc = static_cast<TruncInst *>(Src);
+    if (CSrc->hasNoSignedWrap()) {
+      CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+      if (auto *Trunc = dyn_cast<TruncInst>(I))
+        Trunc->setHasNoSignedWrap(true);
+      return I;
+    }
+
     // If input is a trunc from the destination type, then convert into shifts.
     if (Src->hasOneUse() && X->getType() == DestTy) {
       // sext (trunc X) --> ashr (shl X, C), C
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index bf20de85152909..9eae03470a4693 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -426,8 +426,7 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
 
 define i32 @sext_trunc_nsw(i16 %x) {
 ; CHECK-LABEL: @sext_trunc_nsw(
-; CHECK-NEXT:    [[C:%.*]] = trunc nsw i16 [[X:%.*]] to i8
-; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT:    [[E:%.*]] = sext i16 [[X:%.*]] to i32
 ; CHECK-NEXT:    ret i32 [[E]]
 ;
   %c = trunc nsw i16 %x to i8
@@ -437,8 +436,7 @@ define i32 @sext_trunc_nsw(i16 %x) {
 
 define i16 @sext_trunc_nsw_2(i32 %x) {
 ; CHECK-LABEL: @sext_trunc_nsw_2(
-; CHECK-NEXT:    [[C:%.*]] = trunc nsw i32 [[X:%.*]] to i8
-; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i16
+; CHECK-NEXT:    [[E:%.*]] = trunc nsw i32 [[X:%.*]] to i16
 ; CHECK-NEXT:    ret i16 [[E]]
 ;
   %c = trunc nsw i32 %x to i8
@@ -448,8 +446,7 @@ define i16 @sext_trunc_nsw_2(i32 %x) {
 
 define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
 ; CHECK-LABEL: @sext_trunc_nsw_vec(
-; CHECK-NEXT:    [[C:%.*]] = trunc nsw <2 x i16> [[X:%.*]] to <2 x i8>
-; CHECK-NEXT:    [[E:%.*]] = sext <2 x i8> [[C]] to <2 x i32>
+; CHECK-NEXT:    [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[E]]
 ;
   %c = trunc nsw <2 x i16> %x to <2 x i8>
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index abaec98f777af7..16e7ef143cef9e 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -870,8 +870,7 @@ entry:
 
 define i32 @zext_trunc_nuw(i16 %x) {
 ; CHECK-LABEL: @zext_trunc_nuw(
-; CHECK-NEXT:    [[X:%.*]] = and i16 [[X1:%.*]], 255
-; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[X]] to i32
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[X:%.*]] to i32
 ; CHECK-NEXT:    ret i32 [[E1]]
 ;
   %c = trunc nuw i16 %x to i8
@@ -881,8 +880,7 @@ define i32 @zext_trunc_nuw(i16 %x) {
 
 define i16 @zext_trunc_nuw_2(i32 %x) {
 ; CHECK-LABEL: @zext_trunc_nuw_2(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
-; CHECK-NEXT:    [[E:%.*]] = and i16 [[TMP1]], 255
+; CHECK-NEXT:    [[E:%.*]] = trunc nuw i32 [[X:%.*]] to i16
 ; CHECK-NEXT:    ret i16 [[E]]
 ;
   %c = trunc nuw i32 %x to i8
@@ -892,8 +890,7 @@ define i16 @zext_trunc_nuw_2(i32 %x) {
 
 define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
 ; CHECK-LABEL: @zext_trunc_nuw_vec(
-; CHECK-NEXT:    [[X:%.*]] = and <2 x i16> [[X1:%.*]], <i16 255, i16 255>
-; CHECK-NEXT:    [[E1:%.*]] = zext nneg <2 x i16> [[X]] to <2 x i32>
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[E1]]
 ;
   %c = trunc nuw <2 x i16> %x to <2 x i8>

>From 7707f1194b277b17f01d658422d6f7e7408546a0 Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 22:12:50 +0800
Subject: [PATCH 3/5] test: add tests extending back to the trunc's source type

---
 llvm/test/Transforms/InstCombine/sext.ll | 9 +++++++++
 llvm/test/Transforms/InstCombine/zext.ll | 9 +++++++++
 2 files changed, 18 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index 9eae03470a4693..ad02594f020bc8 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -444,6 +444,15 @@ define i16 @sext_trunc_nsw_2(i32 %x) {
   ret i16 %e
 }
 
+define i16 @sext_trunc_nsw_3(i16 %x) {
+; CHECK-LABEL: @sext_trunc_nsw_3(
+; CHECK-NEXT:    ret i16 [[E:%.*]]
+;
+  %c = trunc nsw i16 %x to i8
+  %e = sext i8 %c to i16
+  ret i16 %e
+}
+
 define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
 ; CHECK-LABEL: @sext_trunc_nsw_vec(
 ; CHECK-NEXT:    [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 16e7ef143cef9e..07e06e6d26a270 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -888,6 +888,15 @@ define i16 @zext_trunc_nuw_2(i32 %x) {
   ret i16 %e
 }
 
+define i16 @zext_trunc_nuw_3(i16 %x) {
+; CHECK-LABEL: @zext_trunc_nuw_3(
+; CHECK-NEXT:    ret i16 [[E:%.*]]
+;
+  %c = trunc nuw i16 %x to i8
+  %e = zext i8 %c to i16
+  ret i16 %e
+}
+
 define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
 ; CHECK-LABEL: @zext_trunc_nuw_vec(
 ; CHECK-NEXT:    [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32>

>From b3ed492b96efa8e69b3502abd1e84e6bec0a628f Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 22:15:24 +0800
Subject: [PATCH 4/5] revise code: use cast<TruncInst> and tidy comments

---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 91c149305bb76c..d6f1e3e7e2f7db 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1189,10 +1189,9 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
     // TODO: Subsume this into EvaluateInDifferentType.
 
     Value *A = CSrc->getOperand(0);
-    // If TRUNC has nuw flag, then convert directly to final type.
+    // If trunc has nuw flag, then convert directly to final type.
     if (CSrc->hasNoUnsignedWrap()) {
-      CastInst *I =
-          CastInst::CreateIntegerCast(A, DestTy, /* isSigned */ false);
+      CastInst *I = CastInst::CreateIntegerCast(A, DestTy, /*isSigned=*/false);
       if (auto *ZExt = dyn_cast<ZExtInst>(I))
         ZExt->setNonNeg();
       if (auto *Trunc = dyn_cast<TruncInst>(I))
@@ -1479,9 +1478,9 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
       return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
 
     // If trunc has nsw flag, then convert directly to final type.
-    auto *CSrc = static_cast<TruncInst *>(Src);
+    auto *CSrc = cast<TruncInst>(Src);
     if (CSrc->hasNoSignedWrap()) {
-      CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+      CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /*isSigned=*/true);
       if (auto *Trunc = dyn_cast<TruncInst>(I))
         Trunc->setHasNoSignedWrap(true);
       return I;

>From e004c573143498a0294a3b386a8031ed548dd47a Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 22:15:51 +0800
Subject: [PATCH 5/5] remove ComputeNumSignBits-based sext(trunc) fold

---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d6f1e3e7e2f7db..022c2b937f1913 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1471,12 +1471,6 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
 
   Value *X;
   if (match(Src, m_Trunc(m_Value(X)))) {
-    // If the input has more sign bits than bits truncated, then convert
-    // directly to final type.
-    unsigned XBitSize = X->getType()->getScalarSizeInBits();
-    if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
-      return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
-
     // If trunc has nsw flag, then convert directly to final type.
     auto *CSrc = cast<TruncInst>(Src);
     if (CSrc->hasNoSignedWrap()) {
@@ -1497,6 +1491,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
     // the logic shift to arithmetic shift and eliminate the cast to
     // intermediate type:
     // sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C)
+    unsigned XBitSize = X->getType()->getScalarSizeInBits();
     Value *Y;
     if (Src->hasOneUse() &&
         match(X, m_LShr(m_Value(Y),



More information about the llvm-commits mailing list