[llvm] 29ac9fa - [InstCombine] collectBitParts - convert to use PatternMatch matchers and avoid IntegerType casts.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 1 08:53:47 PDT 2020


Author: Simon Pilgrim
Date: 2020-10-01T16:44:14+01:00
New Revision: 29ac9fae54c9cbd819ce400d42dd2e76bf5259ab

URL: https://github.com/llvm/llvm-project/commit/29ac9fae54c9cbd819ce400d42dd2e76bf5259ab
DIFF: https://github.com/llvm/llvm-project/commit/29ac9fae54c9cbd819ce400d42dd2e76bf5259ab.diff

LOG: [InstCombine] collectBitParts - convert to use PatternMatch matchers and avoid IntegerType casts.

Make sure we're using getScalarSizeInBits instead of cast<IntegerType> to get Type bit widths.

This is preliminary cleanup before we can start adding vector support to the bswap/bitreverse (element level) matching.

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/Local.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 550745673bd9..0fd0dfa24ce9 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -2832,7 +2832,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     return I->second;
 
   auto &Result = BPS[V] = None;
-  auto BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+  auto BitWidth = V->getType()->getScalarSizeInBits();
 
   // Prevent stack overflow by limiting the recursion depth
   if (Depth == BitPartRecursionMaxDepth) {
@@ -2840,13 +2840,16 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     return Result;
   }
 
-  if (Instruction *I = dyn_cast<Instruction>(V)) {
+  if (auto *I = dyn_cast<Instruction>(V)) {
+    Value *X, *Y;
+    const APInt *C;
+
     // If this is an or instruction, it may be an inner node of the bswap.
-    if (I->getOpcode() == Instruction::Or) {
-      const auto &A = collectBitParts(I->getOperand(0), MatchBSwaps,
-                                      MatchBitReversals, BPS, Depth + 1);
-      const auto &B = collectBitParts(I->getOperand(1), MatchBSwaps,
-                                      MatchBitReversals, BPS, Depth + 1);
+    if (match(V, m_Or(m_Value(X), m_Value(Y)))) {
+      const auto &A =
+          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &B =
+          collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
       if (!A || !B)
         return Result;
 
@@ -2871,15 +2874,15 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     }
 
     // If this is a logical shift by a constant, recurse then shift the result.
-    if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
-      const APInt &BitShift = cast<ConstantInt>(I->getOperand(1))->getValue();
+    if (match(V, m_LogicalShift(m_Value(X), m_APInt(C)))) {
+      const APInt &BitShift = *C;
 
       // Ensure the shift amount is defined.
       if (BitShift.uge(BitWidth))
         return Result;
 
-      const auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
-                                        MatchBitReversals, BPS, Depth + 1);
+      const auto &Res =
+          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
       if (!Res)
         return Result;
       Result = Res;
@@ -2899,9 +2902,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
 
     // If this is a logical 'and' with a mask that clears bits, recurse then
     // unset the appropriate bits.
-    if (I->getOpcode() == Instruction::And &&
-        isa<ConstantInt>(I->getOperand(1))) {
-      const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
+    if (match(V, m_And(m_Value(X), m_APInt(C)))) {
+      const APInt &AndMask = *C;
 
       // Check that the mask allows a multiple of 8 bits for a bswap, for an
       // early exit.
@@ -2909,8 +2911,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
       if (!MatchBitReversals && (NumMaskedBits % 8) != 0)
         return Result;
 
-      const auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
-                                        MatchBitReversals, BPS, Depth + 1);
+      const auto &Res =
+          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
       if (!Res)
         return Result;
       Result = Res;
@@ -2923,15 +2925,14 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
     }
 
     // If this is a zext instruction zero extend the result.
-    if (I->getOpcode() == Instruction::ZExt) {
-      const auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
-                                        MatchBitReversals, BPS, Depth + 1);
+    if (match(V, m_ZExt(m_Value(X)))) {
+      const auto &Res =
+          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
       if (!Res)
         return Result;
 
       Result = BitPart(Res->Provider, BitWidth);
-      auto NarrowBitWidth =
-          cast<IntegerType>(cast<ZExtInst>(I)->getSrcTy())->getBitWidth();
+      auto NarrowBitWidth = X->getType()->getScalarSizeInBits();
       for (unsigned BitIdx = 0; BitIdx < NarrowBitWidth; ++BitIdx)
         Result->Provenance[BitIdx] = Res->Provenance[BitIdx];
       for (unsigned BitIdx = NarrowBitWidth; BitIdx < BitWidth; ++BitIdx)
@@ -2939,40 +2940,33 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
       return Result;
     }
 
-    // Handle intrinsic calls.
-    if (auto *II = dyn_cast<IntrinsicInst>(I)) {
-      Intrinsic::ID IntrinsicID = II->getIntrinsicID();
-
-      // Funnel 'double' shifts take 3 operands, 2 inputs and the shift
-      // amount (modulo).
-      // fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
-      // fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
-      const APInt *Amt;
-      if ((IntrinsicID == Intrinsic::fshl || IntrinsicID == Intrinsic::fshr) &&
-          match(II->getArgOperand(2), m_APInt(Amt))) {
-
-        // We can treat fshr as a fshl by flipping the modulo amount.
-        unsigned ModAmt = Amt->urem(BitWidth);
-        if (IntrinsicID == Intrinsic::fshr)
-          ModAmt = BitWidth - ModAmt;
-
-        const auto &LHS = collectBitParts(II->getArgOperand(0), MatchBSwaps,
-                                          MatchBitReversals, BPS, Depth + 1);
-        const auto &RHS = collectBitParts(II->getArgOperand(1), MatchBSwaps,
-                                          MatchBitReversals, BPS, Depth + 1);
-
-        // Check we have both sources and they are from the same provider.
-        if (!LHS || !RHS || !LHS->Provider || LHS->Provider != RHS->Provider)
-          return Result;
-
-        unsigned StartBitRHS = BitWidth - ModAmt;
-        Result = BitPart(LHS->Provider, BitWidth);
-        for (unsigned BitIdx = 0; BitIdx < StartBitRHS; ++BitIdx)
-          Result->Provenance[BitIdx + ModAmt] = LHS->Provenance[BitIdx];
-        for (unsigned BitIdx = 0; BitIdx < ModAmt; ++BitIdx)
-          Result->Provenance[BitIdx] = RHS->Provenance[BitIdx + StartBitRHS];
+    // Funnel 'double' shifts take 3 operands, 2 inputs and the shift
+    // amount (modulo).
+    // fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
+    // fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
+    if (match(V, m_FShl(m_Value(X), m_Value(Y), m_APInt(C))) ||
+        match(V, m_FShr(m_Value(X), m_Value(Y), m_APInt(C)))) {
+      // We can treat fshr as a fshl by flipping the modulo amount.
+      unsigned ModAmt = C->urem(BitWidth);
+      if (cast<IntrinsicInst>(I)->getIntrinsicID() == Intrinsic::fshr)
+        ModAmt = BitWidth - ModAmt;
+
+      const auto &LHS =
+          collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+      const auto &RHS =
+          collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+
+      // Check we have both sources and they are from the same provider.
+      if (!LHS || !RHS || !LHS->Provider || LHS->Provider != RHS->Provider)
         return Result;
-      }
+
+      unsigned StartBitRHS = BitWidth - ModAmt;
+      Result = BitPart(LHS->Provider, BitWidth);
+      for (unsigned BitIdx = 0; BitIdx < StartBitRHS; ++BitIdx)
+        Result->Provenance[BitIdx + ModAmt] = LHS->Provenance[BitIdx];
+      for (unsigned BitIdx = 0; BitIdx < ModAmt; ++BitIdx)
+        Result->Provenance[BitIdx] = RHS->Provenance[BitIdx + StartBitRHS];
+      return Result;
     }
   }
 


        


More information about the llvm-commits mailing list