[llvm] d68d217 - [InstCombine] Fold `ucmp/scmp(x, y) >> N` to `zext/sext(x < y)` when N is one less than the width of the result of `ucmp/scmp` (#104009)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 10:08:26 PDT 2024


Author: Volodymyr Vasylkun
Date: 2024-08-15T18:08:23+01:00
New Revision: d68d2172f9f1f0659b8b4bdbbeb1ccd290a614b5

URL: https://github.com/llvm/llvm-project/commit/d68d2172f9f1f0659b8b4bdbbeb1ccd290a614b5
DIFF: https://github.com/llvm/llvm-project/commit/d68d2172f9f1f0659b8b4bdbbeb1ccd290a614b5.diff

LOG: [InstCombine] Fold `ucmp/scmp(x, y) >> N` to `zext/sext(x < y)` when N is one less than the width of the result of `ucmp/scmp` (#104009)

Proof: https://alive2.llvm.org/ce/z/4diUqN

---------

Co-authored-by: Nikita Popov <github at npopov.com>

Added: 
    llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 38f8a41214b682..794b384d126eb6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -511,6 +511,21 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   if (match(Op1, m_Or(m_Value(), m_SpecificInt(BitWidth - 1))))
     return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1));
 
+  Instruction *CmpIntr;
+  if ((I.getOpcode() == Instruction::LShr ||
+       I.getOpcode() == Instruction::AShr) &&
+      match(Op0, m_OneUse(m_Instruction(CmpIntr))) &&
+      isa<CmpIntrinsic>(CmpIntr) &&
+      match(Op1, m_SpecificInt(Ty->getScalarSizeInBits() - 1))) {
+    Value *Cmp =
+        Builder.CreateICmp(cast<CmpIntrinsic>(CmpIntr)->getLTPredicate(),
+                           CmpIntr->getOperand(0), CmpIntr->getOperand(1));
+    return CastInst::Create(I.getOpcode() == Instruction::LShr
+                                ? Instruction::ZExt
+                                : Instruction::SExt,
+                            Cmp, Ty);
+  }
+
   return nullptr;
 }
 

diff --git a/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
new file mode 100644
index 00000000000000..93082de93f97a4
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/lshr-ashr-of-uscmp.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare void @use(i8 %val)
+
+; ucmp/scmp(x, y) >> N folds to either zext(x < y) or sext(x < y)
+; if N is one less than the width of result of ucmp/scmp
+define i8 @ucmp_to_zext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_zext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i1 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+  %2 = lshr i8 %1, 7
+  ret i8 %2
+}
+
+define i8 @ucmp_to_sext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_sext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i1 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+  %2 = ashr i8 %1, 7
+  ret i8 %2
+}
+
+define i8 @scmp_to_zext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_to_zext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i1 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.scmp(i32 %x, i32 %y)
+  %2 = lshr i8 %1, 7
+  ret i8 %2
+}
+
+define i8 @scmp_to_sext(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_to_sext(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i1 [[TMP1]] to i8
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.scmp(i32 %x, i32 %y)
+  %2 = ashr i8 %1, 7
+  ret i8 %2
+}
+
+define <4 x i8> @scmp_to_sext_vec(<4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: define <4 x i8> @scmp_to_sext_vec(
+; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <4 x i32> [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i1> [[TMP1]] to <4 x i8>
+; CHECK-NEXT:    ret <4 x i8> [[TMP2]]
+;
+  %1 = call <4 x i8> @llvm.scmp(<4 x i32> %x, <4 x i32> %y)
+  %2 = ashr <4 x i8> %1, <i8 7, i8 7, i8 7, i8 7>
+  ret <4 x i8> %2
+}
+
+; Negative test: incorrect shift amount
+define i8 @ucmp_to_zext_neg1(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_zext_neg1(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], 5
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+  %2 = lshr i8 %1, 5
+  ret i8 %2
+}
+
+; Negative test: shift amount is not a constant
+define i8 @ucmp_to_zext_neg2(i32 %x, i32 %y, i8 %s) {
+; CHECK-LABEL: define i8 @ucmp_to_zext_neg2(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i8 [[S:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], [[S]]
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+  %2 = lshr i8 %1, %s
+  ret i8 %2
+}
+
+; Negative test: the result of ucmp/scmp is used more than once
+define i8 @ucmp_to_zext_neg3(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_to_zext_neg3(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    call void @use(i8 [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], 7
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
+  call void @use(i8 %1)
+  %2 = lshr i8 %1, 7
+  ret i8 %2
+}


        


More information about the llvm-commits mailing list