[llvm] [InstCombine] Fold `sext(trunc nsw)` and `zext(trunc nuw)` (PR #88609)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Apr 13 07:17:38 PDT 2024
https://github.com/YanWQ-monad updated https://github.com/llvm/llvm-project/pull/88609
>From 17d3737524754878672ee3331d5d50bbf1793154 Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 17:17:06 +0800
Subject: [PATCH 1/5] pre-commit: add tests
---
llvm/test/Transforms/InstCombine/sext.ll | 44 ++++++++++++++++++++++++
llvm/test/Transforms/InstCombine/zext.ll | 44 ++++++++++++++++++++++++
2 files changed, 88 insertions(+)
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index e3b6058ce7f806..bf20de85152909 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -423,3 +423,47 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
%s = sext i8 %a to i64
ret i64 %s
}
+
+define i32 @sext_trunc_nsw(i16 %x) {
+; CHECK-LABEL: @sext_trunc_nsw(
+; CHECK-NEXT: [[C:%.*]] = trunc nsw i16 [[X:%.*]] to i8
+; CHECK-NEXT: [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT: ret i32 [[E]]
+;
+ %c = trunc nsw i16 %x to i8
+ %e = sext i8 %c to i32
+ ret i32 %e
+}
+
+define i16 @sext_trunc_nsw_2(i32 %x) {
+; CHECK-LABEL: @sext_trunc_nsw_2(
+; CHECK-NEXT: [[C:%.*]] = trunc nsw i32 [[X:%.*]] to i8
+; CHECK-NEXT: [[E:%.*]] = sext i8 [[C]] to i16
+; CHECK-NEXT: ret i16 [[E]]
+;
+ %c = trunc nsw i32 %x to i8
+ %e = sext i8 %c to i16
+ ret i16 %e
+}
+
+define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @sext_trunc_nsw_vec(
+; CHECK-NEXT: [[C:%.*]] = trunc nsw <2 x i16> [[X:%.*]] to <2 x i8>
+; CHECK-NEXT: [[E:%.*]] = sext <2 x i8> [[C]] to <2 x i32>
+; CHECK-NEXT: ret <2 x i32> [[E]]
+;
+ %c = trunc nsw <2 x i16> %x to <2 x i8>
+ %e = sext <2 x i8> %c to <2 x i32>
+ ret <2 x i32> %e
+}
+
+define i32 @sext_trunc(i16 %x) {
+; CHECK-LABEL: @sext_trunc(
+; CHECK-NEXT: [[C:%.*]] = trunc i16 [[X:%.*]] to i8
+; CHECK-NEXT: [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT: ret i32 [[E]]
+;
+ %c = trunc i16 %x to i8
+ %e = sext i8 %c to i32
+ ret i32 %e
+}
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 88cd9c70af40d8..abaec98f777af7 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -867,3 +867,47 @@ entry:
%res = zext nneg i2 %x to i32
ret i32 %res
}
+
+define i32 @zext_trunc_nuw(i16 %x) {
+; CHECK-LABEL: @zext_trunc_nuw(
+; CHECK-NEXT: [[X:%.*]] = and i16 [[X1:%.*]], 255
+; CHECK-NEXT: [[E1:%.*]] = zext nneg i16 [[X]] to i32
+; CHECK-NEXT: ret i32 [[E1]]
+;
+ %c = trunc nuw i16 %x to i8
+ %e = zext i8 %c to i32
+ ret i32 %e
+}
+
+define i16 @zext_trunc_nuw_2(i32 %x) {
+; CHECK-LABEL: @zext_trunc_nuw_2(
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
+; CHECK-NEXT: [[E:%.*]] = and i16 [[TMP1]], 255
+; CHECK-NEXT: ret i16 [[E]]
+;
+ %c = trunc nuw i32 %x to i8
+ %e = zext i8 %c to i16
+ ret i16 %e
+}
+
+define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @zext_trunc_nuw_vec(
+; CHECK-NEXT: [[X:%.*]] = and <2 x i16> [[X1:%.*]], <i16 255, i16 255>
+; CHECK-NEXT: [[E1:%.*]] = zext nneg <2 x i16> [[X]] to <2 x i32>
+; CHECK-NEXT: ret <2 x i32> [[E1]]
+;
+ %c = trunc nuw <2 x i16> %x to <2 x i8>
+ %e = zext <2 x i8> %c to <2 x i32>
+ ret <2 x i32> %e
+}
+
+define i32 @zext_trunc(i16 %x) {
+; CHECK-LABEL: @zext_trunc(
+; CHECK-NEXT: [[E:%.*]] = and i16 [[X:%.*]], 255
+; CHECK-NEXT: [[E1:%.*]] = zext nneg i16 [[E]] to i32
+; CHECK-NEXT: ret i32 [[E1]]
+;
+ %c = trunc i16 %x to i8
+ %e = zext i8 %c to i32
+ ret i32 %e
+}
>From 42a0818b897b69c2eefe5864c20922f22040a486 Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 17:21:03 +0800
Subject: [PATCH 2/5] InstCombine: fold `sext(trunc nsw)` and `zext(trunc nuw)`
---
.../InstCombine/InstCombineCasts.cpp | 22 ++++++++++++++++++-
llvm/test/Transforms/InstCombine/sext.ll | 9 +++-----
llvm/test/Transforms/InstCombine/zext.ll | 9 +++-----
3 files changed, 27 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 437e9b92c7032f..91c149305bb76c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1188,9 +1188,20 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
if (auto *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast
// TODO: Subsume this into EvaluateInDifferentType.
+ Value *A = CSrc->getOperand(0);
+ // If TRUNC has nuw flag, then convert directly to final type.
+ if (CSrc->hasNoUnsignedWrap()) {
+ CastInst *I =
+ CastInst::CreateIntegerCast(A, DestTy, /* isSigned */ false);
+ if (auto *ZExt = dyn_cast<ZExtInst>(I))
+ ZExt->setNonNeg();
+ if (auto *Trunc = dyn_cast<TruncInst>(I))
+ Trunc->setHasNoUnsignedWrap(true);
+ return I;
+ }
+
// Get the sizes of the types involved. We know that the intermediate type
// will be smaller than A or C, but don't know the relation between A and C.
- Value *A = CSrc->getOperand(0);
unsigned SrcSize = A->getType()->getScalarSizeInBits();
unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
unsigned DstSize = DestTy->getScalarSizeInBits();
@@ -1467,6 +1478,15 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+ // If trunc has nsw flag, then convert directly to final type.
+ auto *CSrc = static_cast<TruncInst *>(Src);
+ if (CSrc->hasNoSignedWrap()) {
+ CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+ if (auto *Trunc = dyn_cast<TruncInst>(I))
+ Trunc->setHasNoSignedWrap(true);
+ return I;
+ }
+
// If input is a trunc from the destination type, then convert into shifts.
if (Src->hasOneUse() && X->getType() == DestTy) {
// sext (trunc X) --> ashr (shl X, C), C
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index bf20de85152909..9eae03470a4693 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -426,8 +426,7 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
define i32 @sext_trunc_nsw(i16 %x) {
; CHECK-LABEL: @sext_trunc_nsw(
-; CHECK-NEXT: [[C:%.*]] = trunc nsw i16 [[X:%.*]] to i8
-; CHECK-NEXT: [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT: [[E:%.*]] = sext i16 [[X:%.*]] to i32
; CHECK-NEXT: ret i32 [[E]]
;
%c = trunc nsw i16 %x to i8
@@ -437,8 +436,7 @@ define i32 @sext_trunc_nsw(i16 %x) {
define i16 @sext_trunc_nsw_2(i32 %x) {
; CHECK-LABEL: @sext_trunc_nsw_2(
-; CHECK-NEXT: [[C:%.*]] = trunc nsw i32 [[X:%.*]] to i8
-; CHECK-NEXT: [[E:%.*]] = sext i8 [[C]] to i16
+; CHECK-NEXT: [[E:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: ret i16 [[E]]
;
%c = trunc nsw i32 %x to i8
@@ -448,8 +446,7 @@ define i16 @sext_trunc_nsw_2(i32 %x) {
define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
; CHECK-LABEL: @sext_trunc_nsw_vec(
-; CHECK-NEXT: [[C:%.*]] = trunc nsw <2 x i16> [[X:%.*]] to <2 x i8>
-; CHECK-NEXT: [[E:%.*]] = sext <2 x i8> [[C]] to <2 x i32>
+; CHECK-NEXT: [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
; CHECK-NEXT: ret <2 x i32> [[E]]
;
%c = trunc nsw <2 x i16> %x to <2 x i8>
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index abaec98f777af7..16e7ef143cef9e 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -870,8 +870,7 @@ entry:
define i32 @zext_trunc_nuw(i16 %x) {
; CHECK-LABEL: @zext_trunc_nuw(
-; CHECK-NEXT: [[X:%.*]] = and i16 [[X1:%.*]], 255
-; CHECK-NEXT: [[E1:%.*]] = zext nneg i16 [[X]] to i32
+; CHECK-NEXT: [[E1:%.*]] = zext nneg i16 [[X:%.*]] to i32
; CHECK-NEXT: ret i32 [[E1]]
;
%c = trunc nuw i16 %x to i8
@@ -881,8 +880,7 @@ define i32 @zext_trunc_nuw(i16 %x) {
define i16 @zext_trunc_nuw_2(i32 %x) {
; CHECK-LABEL: @zext_trunc_nuw_2(
-; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i16
-; CHECK-NEXT: [[E:%.*]] = and i16 [[TMP1]], 255
+; CHECK-NEXT: [[E:%.*]] = trunc nuw i32 [[X:%.*]] to i16
; CHECK-NEXT: ret i16 [[E]]
;
%c = trunc nuw i32 %x to i8
@@ -892,8 +890,7 @@ define i16 @zext_trunc_nuw_2(i32 %x) {
define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
; CHECK-LABEL: @zext_trunc_nuw_vec(
-; CHECK-NEXT: [[X:%.*]] = and <2 x i16> [[X1:%.*]], <i16 255, i16 255>
-; CHECK-NEXT: [[E1:%.*]] = zext nneg <2 x i16> [[X]] to <2 x i32>
+; CHECK-NEXT: [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32>
; CHECK-NEXT: ret <2 x i32> [[E1]]
;
%c = trunc nuw <2 x i16> %x to <2 x i8>
>From 7707f1194b277b17f01d658422d6f7e7408546a0 Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 22:12:50 +0800
Subject: [PATCH 3/5] test: add more tests
---
llvm/test/Transforms/InstCombine/sext.ll | 9 +++++++++
llvm/test/Transforms/InstCombine/zext.ll | 9 +++++++++
2 files changed, 18 insertions(+)
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index 9eae03470a4693..ad02594f020bc8 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -444,6 +444,15 @@ define i16 @sext_trunc_nsw_2(i32 %x) {
ret i16 %e
}
+define i16 @sext_trunc_nsw_3(i16 %x) {
+; CHECK-LABEL: @sext_trunc_nsw_3(
+; CHECK-NEXT: ret i16 [[E:%.*]]
+;
+ %c = trunc nsw i16 %x to i8
+ %e = sext i8 %c to i16
+ ret i16 %e
+}
+
define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
; CHECK-LABEL: @sext_trunc_nsw_vec(
; CHECK-NEXT: [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 16e7ef143cef9e..07e06e6d26a270 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -888,6 +888,15 @@ define i16 @zext_trunc_nuw_2(i32 %x) {
ret i16 %e
}
+define i16 @zext_trunc_nuw_3(i16 %x) {
+; CHECK-LABEL: @zext_trunc_nuw_3(
+; CHECK-NEXT: ret i16 [[E:%.*]]
+;
+ %c = trunc nuw i16 %x to i8
+ %e = zext i8 %c to i16
+ ret i16 %e
+}
+
define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
; CHECK-LABEL: @zext_trunc_nuw_vec(
; CHECK-NEXT: [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32>
>From b3ed492b96efa8e69b3502abd1e84e6bec0a628f Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 22:15:24 +0800
Subject: [PATCH 4/5] revise code
---
llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 91c149305bb76c..d6f1e3e7e2f7db 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1189,10 +1189,9 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
// TODO: Subsume this into EvaluateInDifferentType.
Value *A = CSrc->getOperand(0);
- // If TRUNC has nuw flag, then convert directly to final type.
+ // If trunc has nuw flag, then convert directly to final type.
if (CSrc->hasNoUnsignedWrap()) {
- CastInst *I =
- CastInst::CreateIntegerCast(A, DestTy, /* isSigned */ false);
+ CastInst *I = CastInst::CreateIntegerCast(A, DestTy, /*isSigned=*/false);
if (auto *ZExt = dyn_cast<ZExtInst>(I))
ZExt->setNonNeg();
if (auto *Trunc = dyn_cast<TruncInst>(I))
@@ -1479,9 +1478,9 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
// If trunc has nsw flag, then convert directly to final type.
- auto *CSrc = static_cast<TruncInst *>(Src);
+ auto *CSrc = cast<TruncInst>(Src);
if (CSrc->hasNoSignedWrap()) {
- CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+ CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /*isSigned=*/true);
if (auto *Trunc = dyn_cast<TruncInst>(I))
Trunc->setHasNoSignedWrap(true);
return I;
>From e004c573143498a0294a3b386a8031ed548dd47a Mon Sep 17 00:00:00 2001
From: YanWQ-monad <YanWQmonad at gmail.com>
Date: Sat, 13 Apr 2024 22:15:51 +0800
Subject: [PATCH 5/5] remove unused fold
---
llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d6f1e3e7e2f7db..022c2b937f1913 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1471,12 +1471,6 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
Value *X;
if (match(Src, m_Trunc(m_Value(X)))) {
- // If the input has more sign bits than bits truncated, then convert
- // directly to final type.
- unsigned XBitSize = X->getType()->getScalarSizeInBits();
- if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
- return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
-
// If trunc has nsw flag, then convert directly to final type.
auto *CSrc = cast<TruncInst>(Src);
if (CSrc->hasNoSignedWrap()) {
@@ -1497,6 +1491,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
// the logic shift to arithmetic shift and eliminate the cast to
// intermediate type:
// sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C)
+ unsigned XBitSize = X->getType()->getScalarSizeInBits();
Value *Y;
if (Src->hasOneUse() &&
match(X, m_LShr(m_Value(Y),
More information about the llvm-commits mailing list