[llvm] [InstCombine] Improve shamt range calculation (PR #72646)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 27 21:36:08 PST 2023


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/72646

From 94dcda5865a6e598069bf025a508e25361df2b30 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 17 Nov 2023 20:34:47 +0800
Subject: [PATCH] [InstCombine] Improve shamt range calculation

---
 llvm/include/llvm/Analysis/ValueTracking.h            | 5 +++++
 llvm/lib/Analysis/ValueTracking.cpp                   | 8 ++++----
 llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 7 ++++---
 llvm/test/Transforms/InstCombine/and-narrow.ll        | 2 +-
 llvm/test/Transforms/InstCombine/shift-shift.ll       | 4 ++--
 5 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 01eb8532d1f56d2..ba7bfcab03482f6 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -888,6 +888,11 @@ ConstantRange computeConstantRange(const Value *V, bool ForSigned,
                                    const DominatorTree *DT = nullptr,
                                    unsigned Depth = 0);
 
+/// Combine constant ranges from computeConstantRange() and computeKnownBits().
+ConstantRange
+computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
+                                       bool ForSigned, const SimplifyQuery &SQ);
+
 /// Return true if this function can prove that the instruction I will
 /// always transfer execution to one of its successors (including the next
 /// instruction that follows within a basic block). E.g. this is not
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 19af949e5216e15..1d7ba9cd111ef00 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6233,10 +6233,10 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
 }
 
 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
-static ConstantRange
-computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
-                                       bool ForSigned,
-                                       const SimplifyQuery &SQ) {
+ConstantRange
+llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
+                                             bool ForSigned,
+                                             const SimplifyQuery &SQ) {
   ConstantRange CR1 =
       ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
   ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c72eb0f74de8e6d..f70780baa11afaa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -12,6 +12,7 @@
 
 #include "InstCombineInternal.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
@@ -962,12 +963,12 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   }
 
   // Compute what we know about shift count.
-  KnownBits KnownCnt =
-      computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  ConstantRange KnownCnt = computeConstantRangeIncludingKnownBits(
+      I.getOperand(1), /* ForSigned */ false, Q);
   unsigned BitWidth = KnownCnt.getBitWidth();
   // Since shift produces a poison value if RHS is equal to or larger than the
   // bit width, we can safely assume that RHS is less than the bit width.
-  uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
+  uint64_t MaxCnt = KnownCnt.getUnsignedMax().getLimitedValue(BitWidth - 1);
 
   KnownBits KnownAmt =
       computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
diff --git a/llvm/test/Transforms/InstCombine/and-narrow.ll b/llvm/test/Transforms/InstCombine/and-narrow.ll
index c8c720f5fbc5534..0cc74008144b738 100644
--- a/llvm/test/Transforms/InstCombine/and-narrow.ll
+++ b/llvm/test/Transforms/InstCombine/and-narrow.ll
@@ -190,7 +190,7 @@ define <2 x i16> @zext_lshr_vec_undef(<2 x i8> %x) {
 define <2 x i16> @zext_shl_vec_overshift(<2 x i8> %x) {
 ; CHECK-LABEL: @zext_shl_vec_overshift(
 ; CHECK-NEXT:    [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
-; CHECK-NEXT:    [[B:%.*]] = shl <2 x i16> [[Z]], <i16 8, i16 2>
+; CHECK-NEXT:    [[B:%.*]] = shl nuw <2 x i16> [[Z]], <i16 8, i16 2>
 ; CHECK-NEXT:    [[R:%.*]] = and <2 x i16> [[B]], [[Z]]
 ; CHECK-NEXT:    ret <2 x i16> [[R]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/shift-shift.ll b/llvm/test/Transforms/InstCombine/shift-shift.ll
index 4270a2c2ba205ed..ad25160d10c8a64 100644
--- a/llvm/test/Transforms/InstCombine/shift-shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift-shift.ll
@@ -531,7 +531,7 @@ define <2 x i6> @shl_lshr_demand5_undef_right(<2 x i8> %x) {
 define <2 x i6> @shl_lshr_demand5_nonuniform_vec_left(<2 x i8> %x) {
 ; CHECK-LABEL: @shl_lshr_demand5_nonuniform_vec_left(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i8> <i8 -108, i8 -108>, [[X:%.*]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr <2 x i8> [[SHL]], <i8 1, i8 2>
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr exact <2 x i8> [[SHL]], <i8 1, i8 2>
 ; CHECK-NEXT:    [[R:%.*]] = trunc <2 x i8> [[LSHR]] to <2 x i6>
 ; CHECK-NEXT:    ret <2 x i6> [[R]]
 ;
@@ -694,7 +694,7 @@ define <2 x i8> @lshr_shl_demand5_undef_right(<2 x i8> %x) {
 define <2 x i8> @lshr_shl_demand5_nonuniform_vec_left(<2 x i8> %x) {
 ; CHECK-LABEL: @lshr_shl_demand5_nonuniform_vec_left(
 ; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i8> <i8 45, i8 45>, [[X:%.*]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i8> [[SHR]], <i8 1, i8 2>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw <2 x i8> [[SHR]], <i8 1, i8 2>
 ; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[SHL]], <i8 108, i8 108>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;



More information about the llvm-commits mailing list