[llvm] bed5876 - [AggressiveInstCombine] Add arithmetic shift right instr to `TruncInstCombine` DAG

Anton Afanasyev via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 24 00:41:35 PDT 2021


Author: Anton Afanasyev
Date: 2021-08-24T10:41:16+03:00
New Revision: bed587631f9051f879b0672e52a58b9e8b8faab9

URL: https://github.com/llvm/llvm-project/commit/bed587631f9051f879b0672e52a58b9e8b8faab9
DIFF: https://github.com/llvm/llvm-project/commit/bed587631f9051f879b0672e52a58b9e8b8faab9.diff

LOG: [AggressiveInstCombine] Add arithmetic shift right instr to `TruncInstCombine` DAG

Add the `ashr` instruction to the DAG post-dominated by `trunc`, allowing
`TruncInstCombine` to reduce the bitwidth of expressions containing
these instructions.

The shift amount must be less than the target bitwidth.
It is also sufficient to require that all truncated bits
of the value being shifted are sign bits (all zeros or all ones) and
that at least one sign bit remains untruncated: https://alive2.llvm.org/ce/z/Ajo2__

Part of https://reviews.llvm.org/D107766

Differential Revision: https://reviews.llvm.org/D108355

Added: 
    

Modified: 
    llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
    llvm/test/Transforms/AggressiveInstCombine/trunc_ashr.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 10a5f47104d5f..5d66533f04e0f 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -65,6 +65,7 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
   case Instruction::Xor:
   case Instruction::Shl:
   case Instruction::LShr:
+  case Instruction::AShr:
     Ops.push_back(I->getOperand(0));
     Ops.push_back(I->getOperand(1));
     break;
@@ -133,6 +134,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
     case Instruction::Xor:
     case Instruction::Shl:
     case Instruction::LShr:
+    case Instruction::AShr:
     case Instruction::Select: {
       SmallVector<Value *, 2> Operands;
       getRelevantOperands(I, Operands);
@@ -143,8 +145,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
       // TODO: Can handle more cases here:
       // 1. shufflevector, extractelement, insertelement
       // 2. udiv, urem
-      // 3. ashr
-      // 4. phi node(and loop handling)
+      // 3. phi node(and loop handling)
       // ...
       return false;
     }
@@ -277,14 +278,16 @@ Type *TruncInstCombine::getBestTruncatedType() {
       CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
 
   // Initialize MinBitWidth for shift instructions with the minimum number
-  // that is greater than shift amount (i.e. shift amount + 1). For `lshr`
-  // adjust MinBitWidth so that all potentially truncated bits of
-  // the value-to-be-shifted are zeros.
-  // Also normalize MinBitWidth not to be greater than source bitwidth.
+  // that is greater than shift amount (i.e. shift amount + 1).
+  // For `lshr` adjust MinBitWidth so that all potentially truncated
+  // bits of the value-to-be-shifted are zeros.
+  // For `ashr` adjust MinBitWidth so that all potentially truncated
+  // bits of the value-to-be-shifted are sign bits (all zeros or ones)
+  // and even one (first) untruncated bit is sign bit.
+  // Exit early if MinBitWidth is not less than original bitwidth.
   for (auto &Itr : InstInfoMap) {
     Instruction *I = Itr.first;
-    if (I->getOpcode() == Instruction::Shl ||
-        I->getOpcode() == Instruction::LShr) {
+    if (I->isShift()) {
       KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
       unsigned MinBitWidth = KnownRHS.getMaxValue()
                                  .uadd_sat(APInt(OrigBitWidth, 1))
@@ -295,9 +298,13 @@ Type *TruncInstCombine::getBestTruncatedType() {
         KnownBits KnownLHS = computeKnownBits(I->getOperand(0), DL);
         MinBitWidth =
             std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
-        if (MinBitWidth >= OrigBitWidth)
-          return nullptr;
       }
+      if (I->getOpcode() == Instruction::AShr) {
+        unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0), DL);
+        MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
+      }
+      if (MinBitWidth >= OrigBitWidth)
+        return nullptr;
       Itr.second.MinBitWidth = MinBitWidth;
     }
   }
@@ -390,14 +397,15 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
     case Instruction::Or:
     case Instruction::Xor:
     case Instruction::Shl:
-    case Instruction::LShr: {
+    case Instruction::LShr:
+    case Instruction::AShr: {
       Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
       Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
       Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
       // Preserve `exact` flag since truncation doesn't change exactness
-      if (Opc == Instruction::LShr)
+      if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
         if (auto *ResI = dyn_cast<Instruction>(Res))
-          ResI->setIsExact(I->isExact());
+          ResI->setIsExact(PEO->isExact());
       break;
     }
     case Instruction::Select: {

diff  --git a/llvm/test/Transforms/AggressiveInstCombine/trunc_ashr.ll b/llvm/test/Transforms/AggressiveInstCombine/trunc_ashr.ll
index 57f807e677353..512b708af8977 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/trunc_ashr.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/trunc_ashr.ll
@@ -19,10 +19,8 @@ define i16 @ashr_15_zext(i16 %x) {
 
 define i16 @ashr_sext_15(i16 %x) {
 ; CHECK-LABEL: @ashr_sext_15(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[SEXT]], 15
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[ASHR]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i16 [[X:%.*]], 15
+; CHECK-NEXT:    ret i16 [[ASHR]]
 ;
   %sext = sext i16 %x to i32
   %ashr = ashr i32 %sext, 15
@@ -68,14 +66,13 @@ define i16 @ashr_var_shift_amount(i8 %x, i8 %amt) {
 
 define i16 @ashr_var_bounded_shift_amount(i8 %x, i8 %amt) {
 ; CHECK-LABEL: @ashr_var_bounded_shift_amount(
-; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[ZA:%.*]] = zext i8 [[AMT:%.*]] to i32
-; CHECK-NEXT:    [[ZA2:%.*]] = and i32 [[ZA]], 15
-; CHECK-NEXT:    [[S:%.*]] = ashr i32 [[Z]], [[ZA2]]
-; CHECK-NEXT:    [[A:%.*]] = add i32 [[S]], [[Z]]
-; CHECK-NEXT:    [[S2:%.*]] = ashr i32 [[A]], 2
-; CHECK-NEXT:    [[T:%.*]] = trunc i32 [[S2]] to i16
-; CHECK-NEXT:    ret i16 [[T]]
+; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[ZA:%.*]] = zext i8 [[AMT:%.*]] to i16
+; CHECK-NEXT:    [[ZA2:%.*]] = and i16 [[ZA]], 15
+; CHECK-NEXT:    [[S:%.*]] = ashr i16 [[Z]], [[ZA2]]
+; CHECK-NEXT:    [[A:%.*]] = add i16 [[S]], [[Z]]
+; CHECK-NEXT:    [[S2:%.*]] = ashr i16 [[A]], 2
+; CHECK-NEXT:    ret i16 [[S2]]
 ;
   %z = zext i8 %x to i32
   %za = zext i8 %amt to i32
@@ -108,16 +105,15 @@ define i32 @ashr_check_no_overflow(i32 %x, i16 %amt) {
 
 define void @ashr_big_dag(i16* %a, i8 %b, i8 %c) {
 ; CHECK-LABEL: @ashr_big_dag(
-; CHECK-NEXT:    [[ZEXT1:%.*]] = zext i8 [[B:%.*]] to i32
-; CHECK-NEXT:    [[ZEXT2:%.*]] = zext i8 [[C:%.*]] to i32
-; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[ZEXT1]], [[ZEXT2]]
-; CHECK-NEXT:    [[SFT1:%.*]] = and i32 [[ADD1]], 15
-; CHECK-NEXT:    [[SHR1:%.*]] = ashr i32 [[ADD1]], [[SFT1]]
-; CHECK-NEXT:    [[ADD2:%.*]] = add i32 [[ADD1]], [[SHR1]]
-; CHECK-NEXT:    [[SFT2:%.*]] = and i32 [[ADD2]], 7
-; CHECK-NEXT:    [[SHR2:%.*]] = ashr i32 [[ADD2]], [[SFT2]]
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHR2]] to i16
-; CHECK-NEXT:    store i16 [[TRUNC]], i16* [[A:%.*]], align 2
+; CHECK-NEXT:    [[ZEXT1:%.*]] = zext i8 [[B:%.*]] to i16
+; CHECK-NEXT:    [[ZEXT2:%.*]] = zext i8 [[C:%.*]] to i16
+; CHECK-NEXT:    [[ADD1:%.*]] = add i16 [[ZEXT1]], [[ZEXT2]]
+; CHECK-NEXT:    [[SFT1:%.*]] = and i16 [[ADD1]], 15
+; CHECK-NEXT:    [[SHR1:%.*]] = ashr i16 [[ADD1]], [[SFT1]]
+; CHECK-NEXT:    [[ADD2:%.*]] = add i16 [[ADD1]], [[SHR1]]
+; CHECK-NEXT:    [[SFT2:%.*]] = and i16 [[ADD2]], 7
+; CHECK-NEXT:    [[SHR2:%.*]] = ashr i16 [[ADD2]], [[SFT2]]
+; CHECK-NEXT:    store i16 [[SHR2]], i16* [[A:%.*]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %zext1 = zext i8 %b to i32
@@ -152,13 +148,12 @@ define i8 @ashr_check_not_i8_trunc(i16 %x) {
 
 define <2 x i16> @ashr_vector(<2 x i8> %x) {
 ; CHECK-LABEL: @ashr_vector(
-; CHECK-NEXT:    [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[ZA:%.*]] = and <2 x i32> [[Z]], <i32 7, i32 8>
-; CHECK-NEXT:    [[S:%.*]] = ashr <2 x i32> [[Z]], [[ZA]]
-; CHECK-NEXT:    [[A:%.*]] = add <2 x i32> [[S]], [[Z]]
-; CHECK-NEXT:    [[S2:%.*]] = ashr <2 x i32> [[A]], <i32 4, i32 5>
-; CHECK-NEXT:    [[T:%.*]] = trunc <2 x i32> [[S2]] to <2 x i16>
-; CHECK-NEXT:    ret <2 x i16> [[T]]
+; CHECK-NEXT:    [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
+; CHECK-NEXT:    [[ZA:%.*]] = and <2 x i16> [[Z]], <i16 7, i16 8>
+; CHECK-NEXT:    [[S:%.*]] = ashr <2 x i16> [[Z]], [[ZA]]
+; CHECK-NEXT:    [[A:%.*]] = add <2 x i16> [[S]], [[Z]]
+; CHECK-NEXT:    [[S2:%.*]] = ashr <2 x i16> [[A]], <i16 4, i16 5>
+; CHECK-NEXT:    ret <2 x i16> [[S2]]
 ;
   %z = zext <2 x i8> %x to <2 x i32>
   %za = and <2 x i32> %z, <i32 7, i32 8>
@@ -213,11 +208,9 @@ define <2 x i16> @ashr_vector_large_shift_amount(<2 x i8> %x) {
 
 define i16 @ashr_exact(i16 %x) {
 ; CHECK-LABEL: @ashr_exact(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[ZEXT]], 32767
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i32 [[AND]], 15
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[ASHR]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[AND:%.*]] = and i16 [[X:%.*]], 32767
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i16 [[AND]], 15
+; CHECK-NEXT:    ret i16 [[ASHR]]
 ;
   %zext = zext i16 %x to i32
   %and = and i32 %zext, 32767
@@ -245,12 +238,10 @@ define i16 @ashr_negative_operand(i16 %x) {
 
 define i16 @ashr_negative_operand_but_short(i16 %x) {
 ; CHECK-LABEL: @ashr_negative_operand_but_short(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[ZEXT]], 32767
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 -1, [[AND]]
-; CHECK-NEXT:    [[LSHR2:%.*]] = ashr i32 [[XOR]], 2
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[LSHR2]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[AND:%.*]] = and i16 [[X:%.*]], 32767
+; CHECK-NEXT:    [[XOR:%.*]] = xor i16 -1, [[AND]]
+; CHECK-NEXT:    [[LSHR2:%.*]] = ashr i16 [[XOR]], 2
+; CHECK-NEXT:    ret i16 [[LSHR2]]
 ;
   %zext = zext i16 %x to i32
   %and = and i32 %zext, 32767


        


More information about the llvm-commits mailing list