[llvm] eb0e7ac - [InstCombine] canEvaluateTruncated - use KnownBits to check for inrange shift amounts

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 3 08:02:27 PDT 2020


Author: Simon Pilgrim
Date: 2020-07-03T16:02:10+01:00
New Revision: eb0e7acbd4853c31a1401c8f02586850fee15107

URL: https://github.com/llvm/llvm-project/commit/eb0e7acbd4853c31a1401c8f02586850fee15107
DIFF: https://github.com/llvm/llvm-project/commit/eb0e7acbd4853c31a1401c8f02586850fee15107.diff

LOG: [InstCombine] canEvaluateTruncated - use KnownBits to check for inrange shift amounts

Currently canEvaluateTruncated can only attempt to truncate shifts whose shift amounts are scalar/uniform constants that are in range.

This patch replaces the constant extraction code with KnownBits handling, using KnownBits::getMaxValue() to check that the shift amounts are in range.

This enables support for non-uniform constant shift amounts, and also for variable shift amounts that have been masked into range (e.g. `and %B, 31`). Annoyingly, this still won't work for vectors with (demanded) undef elements, as KnownBits returns nothing in those cases, but it's a definite improvement on what we currently have.

Differential Revision: https://reviews.llvm.org/D83127

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
    llvm/test/Transforms/InstCombine/cast.ll
    llvm/test/Transforms/InstCombine/trunc.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 8d9ebe457231..7b3c503facf1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -377,29 +377,31 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
     break;
   }
   case Instruction::Shl: {
-    // If we are truncating the result of this SHL, and if it's a shift of a
-    // constant amount, we can always perform a SHL in a smaller type.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
-      uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (Amt->getLimitedValue(BitWidth) < BitWidth)
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
-    }
+    // If we are truncating the result of this SHL, and if it's a shift of an
+    // inrange amount, we can always perform a SHL in a smaller type.
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    KnownBits AmtKnownBits =
+        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    if (AmtKnownBits.getMaxValue().ult(BitWidth))
+      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
     break;
   }
   case Instruction::LShr: {
     // If this is a truncate of a logical shr, we can truncate it to a smaller
     // lshr iff we know that the bits we would otherwise be shifting in are
     // already zeros.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
-      uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
-      uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (Amt->getLimitedValue(BitWidth) < BitWidth &&
-          IC.MaskedValueIsZero(I->getOperand(0),
-            APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) {
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
-      }
+    // TODO: It is enough to check that the bits we would be shifting in are
+    //       zero - use AmtKnownBits.getMaxValue().
+    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    KnownBits AmtKnownBits =
+        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
+        IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
+      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
     }
     break;
   }
@@ -409,15 +411,15 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
     // original type and the sign bit of the truncate type are similar.
     // TODO: It is enough to check that the bits we would be shifting in are
     //       similar to sign bit of the truncate type.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
-      uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
-      uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (Amt->getLimitedValue(BitWidth) < BitWidth &&
-          OrigBitWidth - BitWidth <
-              IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
-    }
+    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    KnownBits AmtKnownBits =
+        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    unsigned ShiftedBits = OrigBitWidth - BitWidth;
+    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
+        ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
+      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
     break;
   }
   case Instruction::Trunc:

diff  --git a/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll b/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
index 999b5d58f438..89e4a3c1aaed 100644
--- a/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
+++ b/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
@@ -35,12 +35,10 @@ define <2 x i16> @test1_vec(<2 x i16> %a) {
 
 define <2 x i16> @test1_vec_nonuniform(<2 x i16> %a) {
 ; CHECK-LABEL: @test1_vec_nonuniform(
-; CHECK-NEXT:    [[B:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 9>
-; CHECK-NEXT:    [[D:%.*]] = mul nuw nsw <2 x i32> [[B]], <i32 5, i32 6>
-; CHECK-NEXT:    [[E:%.*]] = or <2 x i32> [[C]], [[D]]
-; CHECK-NEXT:    [[F:%.*]] = trunc <2 x i32> [[E]] to <2 x i16>
-; CHECK-NEXT:    ret <2 x i16> [[F]]
+; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 8, i16 9>
+; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[A]], <i16 5, i16 6>
+; CHECK-NEXT:    [[E:%.*]] = or <2 x i16> [[C]], [[D]]
+; CHECK-NEXT:    ret <2 x i16> [[E]]
 ;
   %b = zext <2 x i16> %a to <2 x i32>
   %c = lshr <2 x i32> %b, <i32 8, i32 9>

diff  --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index 10d59bfff57f..18b411103122 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -502,12 +502,10 @@ define <2 x i16> @test40vec(<2 x i16> %a) {
 
 define <2 x i16> @test40vec_nonuniform(<2 x i16> %a) {
 ; ALL-LABEL: @test40vec_nonuniform(
-; ALL-NEXT:    [[T:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
-; ALL-NEXT:    [[T21:%.*]] = lshr <2 x i32> [[T]], <i32 9, i32 10>
-; ALL-NEXT:    [[T5:%.*]] = shl <2 x i32> [[T]], <i32 8, i32 9>
-; ALL-NEXT:    [[T32:%.*]] = or <2 x i32> [[T21]], [[T5]]
-; ALL-NEXT:    [[R:%.*]] = trunc <2 x i32> [[T32]] to <2 x i16>
-; ALL-NEXT:    ret <2 x i16> [[R]]
+; ALL-NEXT:    [[T21:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 9, i16 10>
+; ALL-NEXT:    [[T5:%.*]] = shl <2 x i16> [[A]], <i16 8, i16 9>
+; ALL-NEXT:    [[T32:%.*]] = or <2 x i16> [[T21]], [[T5]]
+; ALL-NEXT:    ret <2 x i16> [[T32]]
 ;
   %t = zext <2 x i16> %a to <2 x i32>
   %t21 = lshr <2 x i32> %t, <i32 9, i32 10>

diff  --git a/llvm/test/Transforms/InstCombine/trunc.ll b/llvm/test/Transforms/InstCombine/trunc.ll
index 4e9f440978a5..d8a615cc4c9a 100644
--- a/llvm/test/Transforms/InstCombine/trunc.ll
+++ b/llvm/test/Transforms/InstCombine/trunc.ll
@@ -286,12 +286,11 @@ define <2 x i64> @test8_vec(<2 x i32> %A, <2 x i32> %B) {
 
 define <2 x i64> @test8_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test8_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
-; CHECK-NEXT:    [[D:%.*]] = zext <2 x i32> [[B:%.*]] to <2 x i128>
-; CHECK-NEXT:    [[E:%.*]] = shl <2 x i128> [[D]], <i128 32, i128 48>
-; CHECK-NEXT:    [[F:%.*]] = or <2 x i128> [[E]], [[C]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[D:%.*]] = zext <2 x i32> [[B:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[E:%.*]] = shl <2 x i64> [[D]], <i64 32, i64 48>
+; CHECK-NEXT:    [[F:%.*]] = or <2 x i64> [[E]], [[C]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -343,12 +342,11 @@ define i8 @test10(i32 %X) {
 
 define i64 @test11(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test11(
-; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i128
+; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[B:%.*]], 31
-; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i128
-; CHECK-NEXT:    [[F:%.*]] = shl i128 [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc i128 [[F]] to i64
-; CHECK-NEXT:    ret i64 [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[F:%.*]] = shl i64 [[C]], [[E]]
+; CHECK-NEXT:    ret i64 [[F]]
 ;
   %C = zext i32 %A to i128
   %D = zext i32 %B to i128
@@ -360,12 +358,11 @@ define i64 @test11(i32 %A, i32 %B) {
 
 define <2 x i64> @test11_vec(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test11_vec(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = shl <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = shl <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -377,12 +374,11 @@ define <2 x i64> @test11_vec(<2 x i32> %A, <2 x i32> %B) {
 
 define <2 x i64> @test11_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test11_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = shl <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = shl <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -411,12 +407,11 @@ define <2 x i64> @test11_vec_undef(<2 x i32> %A, <2 x i32> %B) {
 
 define i64 @test12(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test12(
-; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i128
+; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[B:%.*]], 31
-; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i128
-; CHECK-NEXT:    [[F:%.*]] = lshr i128 [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc i128 [[F]] to i64
-; CHECK-NEXT:    ret i64 [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[F:%.*]] = lshr i64 [[C]], [[E]]
+; CHECK-NEXT:    ret i64 [[F]]
 ;
   %C = zext i32 %A to i128
   %D = zext i32 %B to i128
@@ -428,12 +423,11 @@ define i64 @test12(i32 %A, i32 %B) {
 
 define <2 x i64> @test12_vec(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test12_vec(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -445,12 +439,11 @@ define <2 x i64> @test12_vec(<2 x i32> %A, <2 x i32> %B) {
 
 define <2 x i64> @test12_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test12_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -479,12 +472,11 @@ define <2 x i64> @test12_vec_undef(<2 x i32> %A, <2 x i32> %B) {
 
 define i64 @test13(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test13(
-; CHECK-NEXT:    [[C:%.*]] = sext i32 [[A:%.*]] to i128
+; CHECK-NEXT:    [[C:%.*]] = sext i32 [[A:%.*]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[B:%.*]], 31
-; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i128
-; CHECK-NEXT:    [[F:%.*]] = ashr i128 [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc i128 [[F]] to i64
-; CHECK-NEXT:    ret i64 [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[F:%.*]] = ashr i64 [[C]], [[E]]
+; CHECK-NEXT:    ret i64 [[F]]
 ;
   %C = sext i32 %A to i128
   %D = zext i32 %B to i128
@@ -496,12 +488,11 @@ define i64 @test13(i32 %A, i32 %B) {
 
 define <2 x i64> @test13_vec(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test13_vec(
-; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = sext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -513,12 +504,11 @@ define <2 x i64> @test13_vec(<2 x i32> %A, <2 x i32> %B) {
 
 define <2 x i64> @test13_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test13_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = sext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>


        


More information about the llvm-commits mailing list