[llvm] r373961 - [InstCombine][NFC] dropRedundantMaskingOfLeftShiftInput(): change how we deal with mask

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 7 13:53:00 PDT 2019


Author: lebedevri
Date: Mon Oct  7 13:53:00 2019
New Revision: 373961

URL: http://llvm.org/viewvc/llvm-project?rev=373961&view=rev
Log:
[InstCombine][NFC] dropRedundantMaskingOfLeftShiftInput(): change how we deal with mask

Summary:
Currently, we pre-check whether we need to produce a mask or not.
This involves some rather magical constants.
I'd like to extend this fold to also handle the situation
when there's also a `trunc` before the outer shift.
That will require another set of magical constants.
It's ugly.

Instead, we can just compute the mask and check
whether the mask is a pass-through (all-ones) or not.
This way we don't need any magic numbers.
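A minimal sketch of the idea for the low-bits mask, not LLVM code: it assumes a 32-bit base type modelled by `uint32_t`, with `uint64_t` standing in for the "twice as wide" extended type, and the helper names `lowBitsMask`, `needMask` and `matchesOldUGECheck` are hypothetical. The mask `~(-1 << SumOfShAmts)` is built in the wide type and then truncated (a 32-bit `~0u << 32` would be undefined in C++, much as an oversized `shl` yields poison in LLVM IR). The exhaustive loop checks that "the mask is all-ones" agrees with the old magic-constant pre-check `SumOfShAmts u>= BitWidth`.

```cpp
#include <cassert>
#include <cstdint>

// Assumption: the base type `Ty` is 32 bits wide (modelled by uint32_t) and
// its extended type is modelled by uint64_t.
constexpr uint32_t BitWidth = 32;

// Hypothetical helper: ~(-1 << SumOfShAmts), computed in the extended type
// (zext, shl, not) and then truncated back to the base type.
uint32_t lowBitsMask(uint32_t SumOfShAmts) {
  assert(SumOfShAmts < 2 * BitWidth && "shift amount out of range");
  uint64_t ExtendedSumOfShAmts = SumOfShAmts;                          // zext
  uint64_t ExtendedInvertedMask = ~uint64_t(0) << ExtendedSumOfShAmts; // shl
  uint64_t ExtendedMask = ~ExtendedInvertedMask;                       // not
  return static_cast<uint32_t>(ExtendedMask);                          // trunc
}

// The mask only needs to be applied if it is not a pass-through (all-ones).
bool needMask(uint32_t Mask) { return Mask != UINT32_MAX; }

// Exhaustively compare against the old pre-check: no mask was needed iff
// SumOfShAmts u>= BitWidth. Each shift amount is below BitWidth.
bool matchesOldUGECheck() {
  for (uint32_t MaskShAmt = 0; MaskShAmt < BitWidth; ++MaskShAmt)
    for (uint32_t ShiftShAmt = 0; ShiftShAmt < BitWidth; ++ShiftShAmt) {
      uint32_t SumOfShAmts = MaskShAmt + ShiftShAmt;
      bool OldNeedsMask = SumOfShAmts < BitWidth;
      if (needMask(lowBitsMask(SumOfShAmts)) != OldNeedsMask)
        return false;
    }
  return true;
}
```

For example, a sum of 5 yields the mask 0x1F (applied), while any sum of 32 or more truncates to all-ones and the mask is simply dropped.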

This change is NFC, other than the fact that we now compute
the mask first and then check whether we need to (and can!) apply it.
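The NFC claim can be sanity-checked the same way for the high-bits mask of the second pattern. Again this is a hypothetical model, not LLVM code: a 32-bit base type, `uint64_t` as the extended type, and invented helper names. Following the new code in the diff, the mask is `-1 l>> (BitWidth - ShAmtsDiff)` computed in the extended type and truncated; the loop checks that it is all-ones exactly when the old pre-check (`ShAmtsDiff` non-negative) said no mask was needed, and that otherwise it clears exactly `-ShAmtsDiff` high bits.

```cpp
#include <cassert>
#include <cstdint>

// Assumption: 32-bit base type (uint32_t), extended type modelled by uint64_t.
constexpr uint32_t BitWidth = 32;

// Hypothetical helper: -1 l>> (BitWidth - ShAmtsDiff), computed in the
// extended type and truncated. ShAmtsDiff lives in the 32-bit type, so a
// negative difference is held in two's complement and the subtraction wraps.
uint32_t highBitsMask(uint32_t ShAmtsDiff) {
  uint32_t NumHighBitsToClear = BitWidth - ShAmtsDiff; // sub (mod 2^32)
  uint64_t Extended = NumHighBitsToClear;              // zext
  assert(Extended < 2 * BitWidth && "shift amount out of range");
  return static_cast<uint32_t>(~uint64_t(0) >> Extended); // lshr, trunc
}

// Exhaustively compare against the old pre-check: a mask was needed iff
// ShAmtsDiff (= ShiftShAmt - MaskShAmt) was negative. When needed, the mask
// must clear exactly -ShAmtsDiff high bits.
bool matchesOldNonNegativeCheck() {
  for (int MaskShAmt = 0; MaskShAmt < int(BitWidth); ++MaskShAmt)
    for (int ShiftShAmt = 0; ShiftShAmt < int(BitWidth); ++ShiftShAmt) {
      int Diff = ShiftShAmt - MaskShAmt;
      uint32_t Mask = highBitsMask(static_cast<uint32_t>(Diff));
      bool OldNeedsMask = Diff < 0;
      if ((Mask != UINT32_MAX) != OldNeedsMask)
        return false;
      if (OldNeedsMask && Mask != (UINT32_MAX >> -Diff))
        return false;
    }
  return true;
}
```

The extra `BitWidth` in the shift amount accounts for the mask living in the wider type: shifting the 64-bit all-ones value right by `BitWidth - ShAmtsDiff` leaves `BitWidth + ShAmtsDiff` low bits set, which is all 32 bits whenever the difference is non-negative.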

Reviewers: spatel

Reviewed By: spatel

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D68470

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp?rev=373961&r1=373960&r2=373961&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp Mon Oct  7 13:53:00 2019
@@ -181,39 +181,29 @@ dropRedundantMaskingOfLeftShiftInput(Bin
         MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
     if (!SumOfShAmts)
       return nullptr; // Did not simplify.
+    // In this pattern SumOfShAmts correlates with the number of low bits
+    // that shall remain in the root value (OuterShift).
+
     Type *Ty = X->getType();
-    unsigned BitWidth = Ty->getScalarSizeInBits();
-    // In this pattern SumOfShAmts correlates with the number of low bits that
-    // shall remain in the root value (OuterShift). If SumOfShAmts is less than
-    // bitwidth, we'll need to also produce a mask to keep SumOfShAmts low bits.
-    // So, does *any* channel need a mask?
-    if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
-                                               APInt(BitWidth, BitWidth)))) {
-      // But for a mask we need to get rid of old masking instruction.
-      if (!Masked->hasOneUse())
-        return nullptr; // Else we can't perform the fold.
-      // The mask must be computed in a type twice as wide to ensure
-      // that no bits are lost if the sum-of-shifts is wider than the base type.
-      Type *ExtendedTy = Ty->getExtendedType();
-      // An extend of an undef value becomes zero because the high bits are
-      // never completely unknown. Replace the the `undef` shift amounts with
-      // final shift bitwidth to ensure that the value remains undef when
-      // creating the subsequent shift op.
-      SumOfShAmts = replaceUndefsWith(
-          SumOfShAmts,
-          ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
-                           ExtendedTy->getScalarType()->getScalarSizeInBits()));
-      auto *ExtendedSumOfShAmts =
-          ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
-      // And compute the mask as usual: ~(-1 << (SumOfShAmts))
-      auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
-      auto *ExtendedInvertedMask =
-          ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
-      auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask);
-      NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
-    } else
-      NewMask = nullptr; // No mask needed.
-    // All good, we can do this fold.
+
+    // The mask must be computed in a type twice as wide to ensure
+    // that no bits are lost if the sum-of-shifts is wider than the base type.
+    Type *ExtendedTy = Ty->getExtendedType();
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the `undef` shift amounts with final
+    // shift bitwidth to ensure that the value remains undef when creating the
+    // subsequent shift op.
+    SumOfShAmts = replaceUndefsWith(
+        SumOfShAmts,
+        ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
+                         ExtendedTy->getScalarType()->getScalarSizeInBits()));
+    auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+    // And compute the mask as usual: ~(-1 << (SumOfShAmts))
+    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+    auto *ExtendedInvertedMask =
+        ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
+    auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask);
+    NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
   } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
              match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
                                  m_Deferred(MaskShAmt)))) {
@@ -223,49 +213,51 @@ dropRedundantMaskingOfLeftShiftInput(Bin
     if (!ShAmtsDiff)
       return nullptr; // Did not simplify.
     // In this pattern ShAmtsDiff correlates with the number of high bits that
-    // shall be unset in the root value (OuterShift). If ShAmtsDiff is negative,
-    // we'll need to also produce a mask to unset ShAmtsDiff high bits.
-    // So, does *any* channel need a mask? (is ShiftShAmt u>= MaskShAmt ?)
-    if (!match(ShAmtsDiff, m_NonNegative())) {
-      // This sub-fold (with mask) is invalid for 'ashr' "masking" instruction.
-      if (match(Masked, m_AShr(m_Value(), m_Value())))
-        return nullptr;
-      // For a mask we need to get rid of old masking instruction.
-      if (!Masked->hasOneUse())
-        return nullptr; // Else we can't perform the fold.
-      Type *Ty = X->getType();
-      unsigned BitWidth = Ty->getScalarSizeInBits();
-      // The mask must be computed in a type twice as wide to ensure
-      // that no bits are lost if the sum-of-shifts is wider than the base type.
-      Type *ExtendedTy = Ty->getExtendedType();
-      // An extend of an undef value becomes zero because the high bits are
-      // never completely unknown. Replace the the `undef` shift amounts with
-      // negated shift bitwidth to ensure that the value remains undef when
-      // creating the subsequent shift op.
-      ShAmtsDiff = replaceUndefsWith(
-          ShAmtsDiff,
-          ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth));
-      auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
-          ConstantExpr::getAdd(
-              ConstantExpr::getNeg(ShAmtsDiff),
-              ConstantInt::get(Ty, BitWidth, /*isSigned=*/false)),
-          ExtendedTy);
-      // And compute the mask as usual: (-1 l>> (ShAmtsDiff))
-      auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
-      auto *ExtendedMask =
-          ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
-      NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
-    } else
-      NewMask = nullptr; // No mask needed.
-    // All good, we can do this fold.
+    // shall be unset in the root value (OuterShift).
+
+    Type *Ty = X->getType();
+    unsigned BitWidth = Ty->getScalarSizeInBits();
+
+    // The mask must be computed in a type twice as wide to ensure
+    // that no bits are lost if the sum-of-shifts is wider than the base type.
+    Type *ExtendedTy = Ty->getExtendedType();
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the `undef` shift amounts with negated
+    // shift bitwidth to ensure that the value remains undef when creating the
+    // subsequent shift op.
+    ShAmtsDiff = replaceUndefsWith(
+        ShAmtsDiff,
+        ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth));
+    auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+        ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), BitWidth,
+                                              /*isSigned=*/false),
+                             ShAmtsDiff),
+        ExtendedTy);
+    // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
+    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+    auto *ExtendedMask =
+        ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+    NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
   } else
     return nullptr; // Don't know anything about this pattern.
 
-  // No 'NUW'/'NSW'!
-  // We no longer know that we won't shift-out non-0 bits.
+  // Does this mask have any unset bits? If not, we don't need to apply it.
+  bool NeedMask = !match(NewMask, m_AllOnes());
+
+  // If we need to apply a mask, there are several more restrictions we have.
+  if (NeedMask) {
+    // The old masking instruction must go away.
+    if (!Masked->hasOneUse())
+      return nullptr;
+    // The original "masking" instruction must not have been `ashr`.
+    if (match(Masked, m_AShr(m_Value(), m_Value())))
+      return nullptr;
+  }
+
+  // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
   auto *NewShift =
       BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt);
-  if (!NewMask)
+  if (!NeedMask)
     return NewShift;
 
   Builder.Insert(NewShift);
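The reordered decision logic at the end of the fold can be summarized with a small model. This is a sketch under stated assumptions: `MaskedInfo`, `FoldResult` and `decideFold` are hypothetical stand-ins for the `Masked->hasOneUse()` and `m_AShr` queries and the resulting choice, and the mask is again modelled as a 32-bit value. What it shows is that the one-use and `ashr` restrictions are now consulted only when the computed mask is not all-ones.

```cpp
#include <cstdint>

// Hypothetical stand-ins (not LLVM types) for the facts the fold consults.
struct MaskedInfo {
  bool HasOneUse; // models Masked->hasOneUse()
  bool IsAShr;    // models match(Masked, m_AShr(m_Value(), m_Value()))
};

enum class FoldResult { NoFold, ShiftOnly, ShiftAndMask };

// Mirrors the order of decisions after this commit: the mask is always
// computed first, and the extra restrictions apply only if it is needed.
FoldResult decideFold(uint32_t NewMask, MaskedInfo Masked) {
  bool NeedMask = NewMask != UINT32_MAX; // models !match(NewMask, m_AllOnes())
  if (NeedMask) {
    if (!Masked.HasOneUse)
      return FoldResult::NoFold; // the old masking instruction must go away
    if (Masked.IsAShr)
      return FoldResult::NoFold; // the masked sub-fold is invalid for ashr
  }
  return NeedMask ? FoldResult::ShiftAndMask : FoldResult::ShiftOnly;
}
```

For example, an all-ones mask yields just the new shift even when `Masked` has other uses, while a partial mask requires a single-use, non-`ashr` masking instruction.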



