[llvm] [InstCombine] Fold `ucmp/scmp(x, y) >> N` to `zext/sext(x < y)` when N is one less than the width of the result of `ucmp/scmp` (PR #104009)
Volodymyr Vasylkun via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 14 09:11:46 PDT 2024
https://github.com/Poseydon42 updated https://github.com/llvm/llvm-project/pull/104009
>From d345393126a3de19bf8045a95c63d72b67a47626 Mon Sep 17 00:00:00 2001
From: Poseydon42 <vvmposeydon at gmail.com>
Date: Wed, 14 Aug 2024 15:22:26 +0100
Subject: [PATCH 1/4] Precommit tests
---
.../InstCombine/lshr-ashr-of-uscmp.ll | 95 +++++++++++++++++++
1 file changed, 95 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
diff --git a/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
new file mode 100644
index 00000000000000..1d52151dad3141
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare void @use(i8 %val)
+
+; ucmp/scmp(x, y) >> N folds to zext(x < y) for lshr or sext(x < y) for ashr
+; if N is one less than the width of the result of ucmp/scmp
+define i8 @ucmp_to_zext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_zext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 7
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+ %2 = lshr i8 %1, 7
+ ret i8 %2
+}
+
+define i8 @ucmp_to_sext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_sext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[TMP2:%.*]] = ashr i8 [[TMP1]], 7
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+ %2 = ashr i8 %1, 7
+ ret i8 %2
+}
+
+define i8 @scmp_to_zext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_to_zext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 7
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.scmp(i32 %x, i32 %y)
+ %2 = lshr i8 %1, 7
+ ret i8 %2
+}
+
+define i8 @scmp_to_sext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_to_sext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[TMP2:%.*]] = ashr i8 [[TMP1]], 7
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.scmp(i32 %x, i32 %y)
+ %2 = ashr i8 %1, 7
+ ret i8 %2
+}
+
+; Negative test: incorrect shift amount
+define i8 @ucmp_to_zext_neg1(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_zext_neg1(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 5
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+ %2 = lshr i8 %1, 5
+ ret i8 %2
+}
+
+; Negative test: shift amount is not a constant
+define i8 @ucmp_to_zext_neg2(i32 %x, i32 %y, i8 %s) {
+; CHECK-LABEL: define i8 @ucmp_to_zext_neg2(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i8 [[S:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], [[S]]
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+ %2 = lshr i8 %1, %s
+ ret i8 %2
+}
+
+; Negative test: the result of ucmp/scmp is used more than once
+define i8 @ucmp_to_zext_neg3(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_zext_neg3(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: call void @use(i8 [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 7
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+ call void @use(i8 %1)
+ %2 = lshr i8 %1, 7
+ ret i8 %2
+}
>From ed68aa0ed8dff6d5e3b28cf72bcff988a3757bd8 Mon Sep 17 00:00:00 2001
From: Poseydon42 <vvmposeydon at gmail.com>
Date: Wed, 14 Aug 2024 15:23:03 +0100
Subject: [PATCH 2/4] Implement the fold
---
.../InstCombine/InstCombineShifts.cpp | 18 ++++++++++++++++++
.../InstCombine/lshr-ashr-of-uscmp.ll | 16 ++++++++--------
2 files changed, 26 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 38f8a41214b682..74d67e5c5a9a16 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -511,6 +511,24 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
if (match(Op1, m_Or(m_Value(), m_SpecificInt(BitWidth - 1))))
return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1));
+ Instruction *CmpIntr;
+ const APInt *ShiftAmount;
+ if ((I.getOpcode() == Instruction::LShr ||
+ I.getOpcode() == Instruction::AShr) &&
+ match(Op0, m_Instruction(CmpIntr)) && CmpIntr->hasOneUse() &&
+ isa<CmpIntrinsic>(CmpIntr) && match(Op1, m_APInt(ShiftAmount)) &&
+ *ShiftAmount + 1 == Ty->getIntegerBitWidth()) {
+ Value *Cmp = Builder.CreateICmp(
+ cast<CmpIntrinsic>(CmpIntr)->isSigned() ? ICmpInst::ICMP_SLT
+ : ICmpInst::ICMP_ULT,
+ CmpIntr->getOperand(0), CmpIntr->getOperand(1));
+ Instruction *CmpExt =
+ CastInst::Create(I.getOpcode() == Instruction::LShr ? Instruction::ZExt
+ : Instruction::SExt,
+ Cmp, Ty);
+ return CmpExt;
+ }
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
index 1d52151dad3141..62043d1af0c54c 100644
--- a/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
+++ b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
@@ -8,8 +8,8 @@ declare void @use(i8 %val)
define i8 @ucmp_to_zext(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @ucmp_to_zext(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 7
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext i1 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[TMP2]]
;
%1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
@@ -20,8 +20,8 @@ define i8 @ucmp_to_zext(i32 %x, i32 %y) {
define i8 @ucmp_to_sext(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @ucmp_to_sext(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT: [[TMP2:%.*]] = ashr i8 [[TMP1]], 7
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT: [[TMP2:%.*]] = sext i1 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[TMP2]]
;
%1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
@@ -32,8 +32,8 @@ define i8 @ucmp_to_sext(i32 %x, i32 %y) {
define i8 @scmp_to_zext(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_to_zext(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 7
+; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[TMP2:%.*]] = zext i1 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[TMP2]]
;
%1 = call i8 @llvm.scmp(i32 %x, i32 %y)
@@ -44,8 +44,8 @@ define i8 @scmp_to_zext(i32 %x, i32 %y) {
define i8 @scmp_to_sext(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_to_sext(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT: [[TMP2:%.*]] = ashr i8 [[TMP1]], 7
+; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[TMP2:%.*]] = sext i1 [[TMP1]] to i8
; CHECK-NEXT: ret i8 [[TMP2]]
;
%1 = call i8 @llvm.scmp(i32 %x, i32 %y)
>From 95092947853f011b76531f380a12bcba1e612cca Mon Sep 17 00:00:00 2001
From: Poseydon42 <vvmposeydon at gmail.com>
Date: Wed, 14 Aug 2024 16:48:04 +0100
Subject: [PATCH 3/4] Address review comments
---
.../InstCombine/InstCombineShifts.cpp | 21 ++++++++-----------
.../InstCombine/lshr-ashr-of-uscmp.ll | 12 +++++++++++
2 files changed, 21 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 74d67e5c5a9a16..6852a41b68c556 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -512,21 +512,18 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1));
Instruction *CmpIntr;
- const APInt *ShiftAmount;
if ((I.getOpcode() == Instruction::LShr ||
I.getOpcode() == Instruction::AShr) &&
match(Op0, m_Instruction(CmpIntr)) && CmpIntr->hasOneUse() &&
- isa<CmpIntrinsic>(CmpIntr) && match(Op1, m_APInt(ShiftAmount)) &&
- *ShiftAmount + 1 == Ty->getIntegerBitWidth()) {
- Value *Cmp = Builder.CreateICmp(
- cast<CmpIntrinsic>(CmpIntr)->isSigned() ? ICmpInst::ICMP_SLT
- : ICmpInst::ICMP_ULT,
- CmpIntr->getOperand(0), CmpIntr->getOperand(1));
- Instruction *CmpExt =
- CastInst::Create(I.getOpcode() == Instruction::LShr ? Instruction::ZExt
- : Instruction::SExt,
- Cmp, Ty);
- return CmpExt;
+ isa<CmpIntrinsic>(CmpIntr) &&
+ match(Op1, m_SpecificInt(Ty->getScalarSizeInBits() - 1))) {
+ Value *Cmp =
+ Builder.CreateICmp(cast<CmpIntrinsic>(CmpIntr)->getLTPredicate(),
+ CmpIntr->getOperand(0), CmpIntr->getOperand(1));
+ return CastInst::Create(I.getOpcode() == Instruction::LShr
+ ? Instruction::ZExt
+ : Instruction::SExt,
+ Cmp, Ty);
}
return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
index 62043d1af0c54c..93082de93f97a4 100644
--- a/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
+++ b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
@@ -53,6 +53,18 @@ define i8 @scmp_to_sext(i32 %x, i32 %y) {
ret i8 %2
}
+define <4 x i8> @scmp_to_sext_vec(<4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: define <4 x i8> @scmp_to_sext_vec(
+; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <4 x i32> [[X]], [[Y]]
+; CHECK-NEXT: [[TMP2:%.*]] = sext <4 x i1> [[TMP1]] to <4 x i8>
+; CHECK-NEXT: ret <4 x i8> [[TMP2]]
+;
+ %1 = call <4 x i8> @llvm.scmp(<4 x i32> %x, <4 x i32> %y)
+ %2 = ashr <4 x i8> %1, <i8 7, i8 7, i8 7, i8 7>
+ ret <4 x i8> %2
+}
+
; Negative test: incorrect shift amount
define i8 @ucmp_to_zext_neg1(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @ucmp_to_zext_neg1(
>From f4b0181862262734a69a13f130c75dfe48655693 Mon Sep 17 00:00:00 2001
From: Volodymyr Vasylkun <vvmposeydon at gmail.com>
Date: Wed, 14 Aug 2024 17:11:38 +0100
Subject: [PATCH 4/4] Update
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Co-authored-by: Nikita Popov <github at npopov.com>
---
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 6852a41b68c556..794b384d126eb6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -514,7 +514,7 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
Instruction *CmpIntr;
if ((I.getOpcode() == Instruction::LShr ||
I.getOpcode() == Instruction::AShr) &&
- match(Op0, m_Instruction(CmpIntr)) && CmpIntr->hasOneUse() &&
+ match(Op0, m_OneUse(m_Instruction(CmpIntr))) &&
isa<CmpIntrinsic>(CmpIntr) &&
match(Op1, m_SpecificInt(Ty->getScalarSizeInBits() - 1))) {
Value *Cmp =
More information about the llvm-commits
mailing list