[llvm] 29ac9fa - [InstCombine] collectBitParts - convert to use PatternMatch matchers and avoid IntegerType casts.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 1 08:53:47 PDT 2020
Author: Simon Pilgrim
Date: 2020-10-01T16:44:14+01:00
New Revision: 29ac9fae54c9cbd819ce400d42dd2e76bf5259ab
URL: https://github.com/llvm/llvm-project/commit/29ac9fae54c9cbd819ce400d42dd2e76bf5259ab
DIFF: https://github.com/llvm/llvm-project/commit/29ac9fae54c9cbd819ce400d42dd2e76bf5259ab.diff
LOG: [InstCombine] collectBitParts - convert to use PatternMatch matchers and avoid IntegerType casts.
Make sure we're using getScalarSizeInBits instead of cast<IntegerType> to get Type bit widths.
This is preliminary cleanup before we can start adding vector support to the bswap/bitreverse (element level) matching.
Added:
Modified:
llvm/lib/Transforms/Utils/Local.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 550745673bd9..0fd0dfa24ce9 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -2832,7 +2832,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return I->second;
auto &Result = BPS[V] = None;
- auto BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+ auto BitWidth = V->getType()->getScalarSizeInBits();
// Prevent stack overflow by limiting the recursion depth
if (Depth == BitPartRecursionMaxDepth) {
@@ -2840,13 +2840,16 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return Result;
}
- if (Instruction *I = dyn_cast<Instruction>(V)) {
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ Value *X, *Y;
+ const APInt *C;
+
// If this is an or instruction, it may be an inner node of the bswap.
- if (I->getOpcode() == Instruction::Or) {
- const auto &A = collectBitParts(I->getOperand(0), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
- const auto &B = collectBitParts(I->getOperand(1), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
+ if (match(V, m_Or(m_Value(X), m_Value(Y)))) {
+ const auto &A =
+ collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &B =
+ collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
if (!A || !B)
return Result;
@@ -2871,15 +2874,15 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
}
// If this is a logical shift by a constant, recurse then shift the result.
- if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
- const APInt &BitShift = cast<ConstantInt>(I->getOperand(1))->getValue();
+ if (match(V, m_LogicalShift(m_Value(X), m_APInt(C)))) {
+ const APInt &BitShift = *C;
// Ensure the shift amount is defined.
if (BitShift.uge(BitWidth))
return Result;
- const auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
+ const auto &Res =
+ collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
if (!Res)
return Result;
Result = Res;
@@ -2899,9 +2902,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is a logical 'and' with a mask that clears bits, recurse then
// unset the appropriate bits.
- if (I->getOpcode() == Instruction::And &&
- isa<ConstantInt>(I->getOperand(1))) {
- const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
+ if (match(V, m_And(m_Value(X), m_APInt(C)))) {
+ const APInt &AndMask = *C;
// Check that the mask allows a multiple of 8 bits for a bswap, for an
// early exit.
@@ -2909,8 +2911,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
if (!MatchBitReversals && (NumMaskedBits % 8) != 0)
return Result;
- const auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
+ const auto &Res =
+ collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
if (!Res)
return Result;
Result = Res;
@@ -2923,15 +2925,14 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
}
// If this is a zext instruction zero extend the result.
- if (I->getOpcode() == Instruction::ZExt) {
- const auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
+ if (match(V, m_ZExt(m_Value(X)))) {
+ const auto &Res =
+ collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
if (!Res)
return Result;
Result = BitPart(Res->Provider, BitWidth);
- auto NarrowBitWidth =
- cast<IntegerType>(cast<ZExtInst>(I)->getSrcTy())->getBitWidth();
+ auto NarrowBitWidth = X->getType()->getScalarSizeInBits();
for (unsigned BitIdx = 0; BitIdx < NarrowBitWidth; ++BitIdx)
Result->Provenance[BitIdx] = Res->Provenance[BitIdx];
for (unsigned BitIdx = NarrowBitWidth; BitIdx < BitWidth; ++BitIdx)
@@ -2939,40 +2940,33 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return Result;
}
- // Handle intrinsic calls.
- if (auto *II = dyn_cast<IntrinsicInst>(I)) {
- Intrinsic::ID IntrinsicID = II->getIntrinsicID();
-
- // Funnel 'double' shifts take 3 operands, 2 inputs and the shift
- // amount (modulo).
- // fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
- // fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
- const APInt *Amt;
- if ((IntrinsicID == Intrinsic::fshl || IntrinsicID == Intrinsic::fshr) &&
- match(II->getArgOperand(2), m_APInt(Amt))) {
-
- // We can treat fshr as a fshl by flipping the modulo amount.
- unsigned ModAmt = Amt->urem(BitWidth);
- if (IntrinsicID == Intrinsic::fshr)
- ModAmt = BitWidth - ModAmt;
-
- const auto &LHS = collectBitParts(II->getArgOperand(0), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
- const auto &RHS = collectBitParts(II->getArgOperand(1), MatchBSwaps,
- MatchBitReversals, BPS, Depth + 1);
-
- // Check we have both sources and they are from the same provider.
- if (!LHS || !RHS || !LHS->Provider || LHS->Provider != RHS->Provider)
- return Result;
-
- unsigned StartBitRHS = BitWidth - ModAmt;
- Result = BitPart(LHS->Provider, BitWidth);
- for (unsigned BitIdx = 0; BitIdx < StartBitRHS; ++BitIdx)
- Result->Provenance[BitIdx + ModAmt] = LHS->Provenance[BitIdx];
- for (unsigned BitIdx = 0; BitIdx < ModAmt; ++BitIdx)
- Result->Provenance[BitIdx] = RHS->Provenance[BitIdx + StartBitRHS];
+ // Funnel 'double' shifts take 3 operands, 2 inputs and the shift
+ // amount (modulo).
+ // fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
+ // fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
+ if (match(V, m_FShl(m_Value(X), m_Value(Y), m_APInt(C))) ||
+ match(V, m_FShr(m_Value(X), m_Value(Y), m_APInt(C)))) {
+ // We can treat fshr as a fshl by flipping the modulo amount.
+ unsigned ModAmt = C->urem(BitWidth);
+ if (cast<IntrinsicInst>(I)->getIntrinsicID() == Intrinsic::fshr)
+ ModAmt = BitWidth - ModAmt;
+
+ const auto &LHS =
+ collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+ const auto &RHS =
+ collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
+
+ // Check we have both sources and they are from the same provider.
+ if (!LHS || !RHS || !LHS->Provider || LHS->Provider != RHS->Provider)
return Result;
- }
+
+ unsigned StartBitRHS = BitWidth - ModAmt;
+ Result = BitPart(LHS->Provider, BitWidth);
+ for (unsigned BitIdx = 0; BitIdx < StartBitRHS; ++BitIdx)
+ Result->Provenance[BitIdx + ModAmt] = LHS->Provenance[BitIdx];
+ for (unsigned BitIdx = 0; BitIdx < ModAmt; ++BitIdx)
+ Result->Provenance[BitIdx] = RHS->Provenance[BitIdx + StartBitRHS];
+ return Result;
}
}
More information about the llvm-commits mailing list