[llvm] [LoopIdiomRecognizer] Implement CRC recognition (PR #79295)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 21 19:58:43 PST 2024
================
@@ -2868,3 +2889,888 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
++NumShiftUntilZero;
return MadeChange;
}
+
+// Returns the low NumBits bits of Num in mirrored order: bit 0 of Num becomes
+// bit NumBits-1 of the result, bit 1 becomes bit NumBits-2, and so on. Bits of
+// Num at or above NumBits are ignored.
+static uint64_t reverseBits(uint64_t Num, unsigned NumBits) {
+  uint64_t Mirrored = 0;
+  // Walk the destination positions from the top down while consuming Num from
+  // its least significant end.
+  for (unsigned Dst = NumBits; Dst-- > 0; Num >>= 1)
+    if (Num & 1)
+      Mirrored |= uint64_t(1) << Dst;
+  return Mirrored;
+}
+
+class ValueBits {
+  // This is a representation of a value's bits in terms of references to
+  // other values' bits, or 1/0 if the bit is known. This allows symbolic
+  // execution of bitwise instructions without knowing the exact values.
+  // Bits are stored least significant first: Bits[0] is bit 0 of the value.
+  //
+  // Example:
+  //
+  // LLVM IR Value i8 %x:
+  //   [%x[7], %x[6], %x[5], %x[4],   %x[3], %x[2],   %x[1], %x[0]]
+  //
+  // %shr = lshr i8 %x, 2
+  //   [    0,     0, %x[7], %x[6],   %x[5], %x[4],   %x[3], %x[2]]
+  //
+  // %shl = shl i8 %shr, 1
+  //   [    0, %x[7], %x[6], %x[5],   %x[4], %x[3],   %x[2],     0]
+  //
+  // %xor = xor i8 %shl, 0xb
+  //   [    0, %x[7], %x[6], %x[5], %x[4]^1, %x[3], %x[2]^1,     1]
+  //
+  // NOTE: ValueBit and ValueBits objects are heap allocated and handed around
+  // as raw pointers; copies of a ValueBits share their ValueBit objects and
+  // nothing in this class deletes them.
+public:
+  // A single symbolic bit: a known constant, a reference to one bit of an IR
+  // value, or the xor of two other symbolic bits.
+  class ValueBit {
+  public:
+    enum BitType { ONE, ZERO, REF, XOR };
+
+  private:
+    BitType _Type;
+    // Only meaningful for REF bits: the referenced value and the bit index.
+    std::pair<Value *, uint64_t> _BitRef;
+    // Only meaningful for XOR bits. Default to null so that copying a
+    // non-XOR bit never reads an indeterminate pointer.
+    ValueBit *_LHS = nullptr;
+    ValueBit *_RHS = nullptr;
+
+    ValueBit(BitType Type) : _Type(Type) {}
+    ValueBit(BitType Type, std::pair<Value *, uint64_t> BitRef)
+        : _Type(Type), _BitRef(BitRef) {}
+    ValueBit(BitType Type, ValueBit *LHS, ValueBit *RHS)
+        : _Type(Type), _LHS(LHS), _RHS(RHS) {}
+
+  public:
+    static ValueBit *CreateOneBit() { return new ValueBit(BitType::ONE); }
+    static ValueBit *CreateZeroBit() { return new ValueBit(BitType::ZERO); }
+    static ValueBit *CreateRefBit(Value *Ref, uint64_t Offset) {
+      return new ValueBit(BitType::REF, std::make_pair(Ref, Offset));
+    }
+    static ValueBit *CreateXORBit(ValueBit *LHS, ValueBit *RHS) {
+      return new ValueBit(BitType::XOR, LHS, RHS);
+    }
+    BitType getType() const { return _Type; }
+
+    // Structural equality. XOR bits compare equal with their operands in
+    // either order; no other algebraic simplification is attempted.
+    bool equals(ValueBit *RHS) {
+      if (_Type != RHS->getType())
+        return false;
+      switch (_Type) {
+      case BitType::ONE:
+      case BitType::ZERO:
+        return true;
+      case BitType::REF:
+        return _BitRef == RHS->_BitRef;
+      case BitType::XOR:
+        return (_LHS->equals(RHS->_LHS) && _RHS->equals(RHS->_RHS)) ||
+               (_LHS->equals(RHS->_RHS) && _RHS->equals(RHS->_LHS));
+      }
+      return false;
+    }
+
+    void print(raw_ostream &OS) {
+      switch (_Type) {
+      case BitType::ONE:
+        OS << "1";
+        break;
+      case BitType::ZERO:
+        OS << "0";
+        break;
+      case BitType::REF:
+        OS << _BitRef.first->getNameOrAsOperand() << "[" << _BitRef.second
+           << "]";
+        break;
+      case BitType::XOR:
+        _LHS->print(OS);
+        OS << "^";
+        _RHS->print(OS);
+        break;
+      }
+    }
+  };
+
+private:
+  // Zero for objects built through the protected default constructor, which
+  // do not populate Bits.
+  uint64_t Size = 0;
+  std::vector<ValueBit *> Bits;
+
+  // In-place shift left: a zero enters at bit 0 and the top bit is dropped.
+  virtual void _Shl(uint64_t N) {
+    // Shifting by the full width or more clears every bit, so cap the count
+    // instead of looping N times.
+    if (N > Bits.size())
+      N = Bits.size();
+    for (; N > 0; N--) {
+      Bits.pop_back();
+      Bits.insert(Bits.begin(), ValueBit::CreateZeroBit());
+    }
+  }
+  // In-place logical shift right: bit 0 is dropped and a zero enters at the
+  // top.
+  virtual void _LShr(uint64_t N) {
+    if (N > Bits.size())
+      N = Bits.size();
+    for (; N > 0; N--) {
+      Bits.erase(Bits.begin());
+      Bits.push_back(ValueBit::CreateZeroBit());
+    }
+  }
+  // In-place xor with RHS. RHS is read bit by bit through getBit(), so it
+  // must hold its own bits and have the same size as this value.
+  virtual void _Xor(ValueBits *RHS) {
+    assert(Size == RHS->getSize());
+    for (unsigned I = 0; I < Size; I++) {
+      // Slots are overwritten by index; the previous erase/insert sequence
+      // dereferenced the iterator after erasing and so inspected the
+      // neighbouring bit (or one past the end for the last bit).
+      ValueBit *LHSBit = Bits[I];
+      ValueBit *RHSBit = RHS->getBit(I);
+      const ValueBit::BitType LHSType = LHSBit->getType();
+      const ValueBit::BitType RHSType = RHSBit->getType();
+
+      if (RHSType == ValueBit::BitType::ZERO)
+        continue; // x ^ 0 == x
+
+      if (RHSType == ValueBit::BitType::ONE) {
+        if (LHSType == ValueBit::BitType::ZERO)
+          Bits[I] = ValueBit::CreateOneBit();
+        else if (LHSType == ValueBit::BitType::ONE)
+          Bits[I] = ValueBit::CreateZeroBit();
+        else
+          Bits[I] = ValueBit::CreateXORBit(LHSBit, ValueBit::CreateOneBit());
+        continue;
+      }
+
+      // RHS is a REF or XOR bit.
+      if (LHSType == ValueBit::BitType::ZERO)
+        Bits[I] = new ValueBit(*RHSBit); // 0 ^ y == y
+      else
+        Bits[I] = ValueBit::CreateXORBit(LHSBit, RHSBit);
+    }
+  }
+  // Widen to ToSize bits by appending zero bits at the top.
+  virtual void _ZExt(uint64_t ToSize) {
+    assert(ToSize > Size);
+    for (uint64_t I = 0; I < ToSize - Size; I++)
+      Bits.push_back(ValueBit::CreateZeroBit());
+    Size = ToSize;
+  }
+  // Narrow to the low ToSize bits.
+  virtual void _Trunc(uint64_t ToSize) {
+    assert(ToSize < Size);
+    Bits.erase(Bits.begin() + ToSize, Bits.end());
+    Size = ToSize;
+  }
+  // In-place and with a constant mask: every bit whose mask bit is clear
+  // becomes a known zero. Mask bits at index 64 and above count as clear.
+  virtual void _And(uint64_t RHS) {
+    for (unsigned I = 0; I < Size; I++) {
+      if (!(RHS & 1))
+        Bits[I] = ValueBit::CreateZeroBit();
+      RHS >>= 1;
+    }
+  }
+
+protected:
+  ValueBits() = default;
+
+public:
+  // Every bit I refers to bit I of InitialVal.
+  ValueBits(Value *InitialVal, uint64_t BitLength) : Size(BitLength) {
+    for (unsigned i = 0; i < BitLength; i++)
+      Bits.push_back(ValueBit::CreateRefBit(InitialVal, i));
+  }
+  // Every bit is a known constant taken from InitialVal, least significant
+  // bit first.
+  ValueBits(uint64_t InitialVal, uint64_t BitLength) : Size(BitLength) {
+    for (unsigned i = 0; i < BitLength; i++) {
+      if (InitialVal & 0x1)
+        Bits.push_back(ValueBit::CreateOneBit());
+      else
+        Bits.push_back(ValueBit::CreateZeroBit());
+      InitialVal >>= 1;
+    }
+  }
+  ValueBits(const ValueBits &) = default;
+  ValueBits &operator=(const ValueBits &) = default;
+  // The class is used polymorphically through ValueBits pointers.
+  virtual ~ValueBits() = default;
+
+  uint64_t getSize() { return Size; }
+  ValueBit *getBit(unsigned i) { return Bits[i]; }
+
+  // Returns a newly allocated copy of the dynamic type of this object.
+  virtual ValueBits *copyBits() { return new ValueBits(*this); }
+
+  // The static operations below leave their operands untouched and return a
+  // newly allocated result.
+  static ValueBits *Shl(ValueBits *LHS, uint64_t N) {
+    ValueBits *Shifted = LHS->copyBits();
+    Shifted->_Shl(N);
+    return Shifted;
+  }
+  static ValueBits *LShr(ValueBits *LHS, uint64_t N) {
+    ValueBits *Shifted = LHS->copyBits();
+    Shifted->_LShr(N);
+    return Shifted;
+  }
+  static ValueBits *Xor(ValueBits *LHS, ValueBits *RHS) {
+    ValueBits *Xord = LHS->copyBits();
+    Xord->_Xor(RHS);
+    return Xord;
+  }
+  static ValueBits *ZExt(ValueBits *LHS, uint64_t ToSize) {
+    ValueBits *Zexted = LHS->copyBits();
+    Zexted->_ZExt(ToSize);
+    return Zexted;
+  }
+  static ValueBits *Trunc(ValueBits *LHS, uint64_t N) {
+    ValueBits *Trunced = LHS->copyBits();
+    Trunced->_Trunc(N);
+    return Trunced;
+  }
+  static ValueBits *And(ValueBits *LHS, uint64_t RHS) {
+    ValueBits *Anded = LHS->copyBits();
+    Anded->_And(RHS);
+    return Anded;
+  }
+
+  virtual bool isPredicated() { return false; }
+
+  // True when both values have the same size and every pair of bits is
+  // structurally equal.
+  virtual bool equals(ValueBits *RHS) {
+    if (Size != RHS->getSize())
+      return false;
+
+    for (unsigned I = 0; I < Size; I++)
+      if (!getBit(I)->equals(RHS->getBit(I)))
+        return false;
+
+    return true;
+  }
+
+  // Prints the bits most significant first, separated by " | ".
+  virtual void print(raw_ostream &OS) {
+    assert(Size != 0);
+    OS << "[";
+    Bits[Size - 1]->print(OS);
+    for (int i = Size - 2; i >= 0; i--) {
+      OS << " | ";
+      Bits[i]->print(OS);
+    }
+    OS << "]\n";
+  }
+};
+
+// Streams VBS by delegating to its virtual print() member and returns OS so
+// that insertions can be chained.
+inline raw_ostream &operator<<(raw_ostream &OS, ValueBits &VBS) {
+ VBS.print(OS);
+ return OS;
+}
+
+// Streams a single symbolic bit by delegating to ValueBit::print() and
+// returns OS so that insertions can be chained.
+inline raw_ostream &operator<<(raw_ostream &OS, ValueBits::ValueBit &VB) {
+ VB.print(OS);
+ return OS;
+}
+class PredicatedValueBits : public ValueBits {
+  // This is representative of select or phi instructions, where the bits
+  // depend on an icmp: one ValueBits describes the value when the predicate
+  // holds and another describes it when the predicate does not. Every bitwise
+  // operation is forwarded to both arms.
+private:
+  ICmpInst *_Predicate;
+  ValueBits *_IfTrue;
+  ValueBits *_IfFalse;
+
+  // Replaces each arm with the result of Op applied to it, true arm first.
+  template <typename OpT> void transformArms(OpT Op) {
+    _IfTrue = Op(_IfTrue);
+    _IfFalse = Op(_IfFalse);
+  }
+
+  void _Shl(uint64_t N) override {
+    transformArms([N](ValueBits *Arm) { return ValueBits::Shl(Arm, N); });
+  }
+  void _LShr(uint64_t N) override {
+    transformArms([N](ValueBits *Arm) { return ValueBits::LShr(Arm, N); });
+  }
+  void _ZExt(uint64_t N) override {
+    transformArms([N](ValueBits *Arm) { return ValueBits::ZExt(Arm, N); });
+  }
+  void _And(uint64_t N) override {
+    transformArms([N](ValueBits *Arm) { return ValueBits::And(Arm, N); });
+  }
+  void _Xor(ValueBits *RHS) override {
+    transformArms([RHS](ValueBits *Arm) { return ValueBits::Xor(Arm, RHS); });
+  }
+  void _Trunc(uint64_t N) override {
+    transformArms([N](ValueBits *Arm) { return ValueBits::Trunc(Arm, N); });
+  }
+
+public:
+  PredicatedValueBits(ICmpInst *Predicate, ValueBits *IfTrue,
+                      ValueBits *IfFalse)
+      : _Predicate(Predicate), _IfTrue(IfTrue), _IfFalse(IfFalse) {}
+
+  // The copy shares both arms with the original until an operation replaces
+  // them.
+  ValueBits *copyBits() override { return new PredicatedValueBits(*this); }
+  bool isPredicated() override { return true; }
+  ValueBits *getIfTrue() { return _IfTrue; }
+  ValueBits *getIfFalse() { return _IfFalse; }
+  ICmpInst *getPredicate() { return _Predicate; }
+
+  // Prints the predicate followed by the true arm and then the false arm.
+  void print(raw_ostream &OS) override {
+    OS << "Predicate: " << *_Predicate << "\nIf True:\n";
+    OS << *_IfTrue;
+    OS << "If False:\n";
+    OS << *_IfFalse;
+  }
+};
+
+// Execute the instructions in a basic block whilst mapping out Values to
+// ValueBits.
+//
+// BB is walked once, top to bottom. For every PHI, shl, lshr, and, xor, zext,
+// trunc and select, ValueMap receives an entry describing the result bits in
+// terms of the operands' bits. Operands that have no entry yet are modelled
+// as known constants (ConstantInt) or as opaque references to their own bits.
+// Returns false as soon as an instruction of one of those kinds cannot be
+// modelled; ValueMap may already contain entries in that case.
+//
+// NOTE: the ValueBits objects created here are allocated with new and are
+// not freed by this function.
+static bool symbolicallyExecute(BasicBlock *BB,
+                                std::map<Value *, ValueBits *> &ValueMap) {
+
+  // Returns operand number Operand of I if it is a ConstantInt, otherwise
+  // logs and returns null. The index is an unsigned rather than a uint8_t so
+  // that the debug stream prints it as a number and not as a character.
+  auto getConstantOperand = [](Instruction *I, unsigned Operand) {
+    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(Operand));
+    if (!CI) {
+      LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                        << " handle this operation with non-constant operand "
+                        << Operand << ":\n"
+                        << *I << "\n");
+    }
+    return CI;
+  };
+
+  auto getOrCreateValueBits = [&ValueMap](Value *Val) -> ValueBits * {
+    auto Result = ValueMap.find(Val);
+    if (Result != ValueMap.end())
+      return Result->second;
+
+    const uint64_t Width = Val->getType()->getScalarSizeInBits();
+    // Known-constant bits are built from a 64-bit integer, so only constants
+    // whose set bits all fit in 64 bits can be modelled exactly. Anything
+    // else falls back to opaque reference bits, which is conservative.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Val))
+      if (CI->getValue().getActiveBits() <= 64)
+        return new ValueBits(CI->getZExtValue(), Width);
+    return new ValueBits(Val, Width);
+  };
+
+  for (Instruction &I : *BB) {
+    uint64_t BitSize = I.getType()->getScalarSizeInBits();
+    switch (I.getOpcode()) {
+    case Instruction::PHI: {
+      PHINode *PHI = cast<PHINode>(&I);
+      // Model the PHI by the value flowing in from the single predecessor
+      // other than BB itself.
+      const BasicBlock *IncomingBlock = nullptr;
+      for (const BasicBlock *Incoming : PHI->blocks()) {
+        if (Incoming != BB) {
+          if (IncomingBlock) {
+            LLVM_DEBUG(dbgs()
+                       << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                       << " handle loop with multiple entries" << I << "\n");
+            return false;
+          }
+          IncomingBlock = Incoming;
+        }
+      }
+      if (!IncomingBlock) {
+        // Every incoming edge comes from BB itself; there is no initial value
+        // to start from.
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                          << " handle PHI without an incoming value from"
+                          << " outside the block" << I << "\n");
+        return false;
+      }
+      ValueMap[&I] =
+          getOrCreateValueBits(PHI->getIncomingValueForBlock(IncomingBlock));
+    } break;
+    case Instruction::Shl: {
+      ConstantInt *CI = getConstantOperand(&I, 1);
+      if (!CI)
+        return false;
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      // Shift amounts are unsigned. Saturate at the bit width: a larger
+      // amount is modelled as all zeros, the same as shifting by the width.
+      ValueMap[&I] = ValueBits::Shl(LHSBits, CI->getLimitedValue(BitSize));
+    } break;
+    case Instruction::LShr: {
+      ConstantInt *CI = getConstantOperand(&I, 1);
+      if (!CI)
+        return false;
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::LShr(LHSBits, CI->getLimitedValue(BitSize));
+    } break;
+    case Instruction::And: {
+      ConstantInt *CI = getConstantOperand(&I, 1);
+      if (!CI)
+        return false;
+      // ValueBits::And takes a 64-bit mask and treats all higher mask bits as
+      // clear, so a mask with set bits above bit 63 cannot be modelled.
+      if (CI->getValue().getActiveBits() > 64) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                          << " handle and with a mask wider than 64 bits: "
+                          << I << "\n");
+        return false;
+      }
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::And(LHSBits, CI->getZExtValue());
+    } break;
+    case Instruction::Xor: {
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueBits *RHSBits = getOrCreateValueBits(I.getOperand(1));
+      // ValueBits::Xor reads its right-hand side bit by bit, and a predicated
+      // value holds no bits of its own. Xor is commutative, so move a
+      // predicated operand to the left, where it is applied to both arms.
+      if (RHSBits->isPredicated()) {
+        if (LHSBits->isPredicated()) {
+          LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                            << " handle xor of two predicated values: " << I
+                            << "\n");
+          return false;
+        }
+        ValueBits *Plain = LHSBits;
+        LHSBits = RHSBits;
+        RHSBits = Plain;
+      }
+      ValueMap[&I] = ValueBits::Xor(LHSBits, RHSBits);
+    } break;
+    case Instruction::ZExt: {
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::ZExt(LHSBits, BitSize);
+    } break;
+    case Instruction::Trunc: {
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::Trunc(LHSBits, BitSize);
+    } break;
+    case Instruction::Select: {
+      SelectInst *Select = cast<SelectInst>(&I);
+      ICmpInst *Cond = dyn_cast<ICmpInst>(Select->getCondition());
+      if (!Cond) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                          << " handle SelectInst with non-icmp condition: " << I
+                          << "\n");
+        return false;
+      }
+      ValueBits *IfTrue = getOrCreateValueBits(Select->getTrueValue());
+      ValueBits *IfFalse = getOrCreateValueBits(Select->getFalseValue());
+      ValueMap[&I] = new PredicatedValueBits(Cond, IfTrue, IfFalse);
+    } break;
+    default:
+      // If this instruction is not recognized, then just continue. This is
+      // okay because users of this will just reference it by value, which is
+      // conservative.
+      break;
+    }
+  }
+  return true;
+}
+
+void LoopIdiomRecognize::writeTableBasedCRCOneByte(CRCInfo &CRC) {
+ BasicBlock *ExitBB = CurLoop->getExitBlock();
+ IRBuilder<> Builder(ExitBB);
+ Builder.SetInsertPoint(ExitBB->getFirstNonPHI());
+ Type *CRCType = CRC.CRCInput->getType();
+ uint64_t CRCSize = CRCType->getScalarSizeInBits();
+
+ // Construct the CRC table
+ uint64_t CRCTable[256];
+ uint64_t Polynomial = CRC.Polynomial;
+ uint64_t SB = CRC.BitReversed ? 0x1 : (0x1 << (CRCSize - 1));
+ if (CRC.BitReversed)
+ Polynomial = reverseBits(Polynomial, CRCSize);
+ for (uint64_t Dividend = 0; Dividend < 256; Dividend++) {
+ uint64_t CurByte = Dividend;
+ if (!CRC.BitReversed)
+ CurByte <<= CRCSize - 8;
+ for (uint8_t Bit = 0; Bit < 8; Bit++) {
+ if ((CurByte & SB) != 0) {
+ CurByte = CRC.BitReversed ? CurByte >> 1 : CurByte << 1;
+ CurByte = CurByte ^ Polynomial;
+ } else {
+ CurByte = CRC.BitReversed ? CurByte >> 1 : CurByte << 1;
+ }
+ }
+ CRCTable[Dividend] = CurByte;
+ }
+ // To construct a global data array, we need the raw data in bytes.
+ // The calculated table array is an array of 64bit values because we can't
+ // dynamically type it, so we need to truncate the values to the crc size
+ // to avoid padded zeros. Do this by allocating a byte array (of slightly more
+ // than we need to account for overflow) and copying the 64bit values across
+ // aligned correctly
+ uint64_t CRCNumBytes = CRCSize / 8;
+ char *CRCTableData = (char *)malloc(CRCNumBytes * 260);
----------------
topperc wrote:
Please use array `new[]` and `delete[]` instead of `malloc`.
Also use `uint8_t` or `int8_t` instead of `char`; the signedness of `char` is platform dependent.
https://github.com/llvm/llvm-project/pull/79295
More information about the llvm-commits
mailing list