[llvm] [LoopIdiomRecognizer] Implement CRC recognition (PR #79295)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 24 06:28:59 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Joe Faulls (joe-img)

<details>
<summary>Changes</summary>

Recognizes CRC byte loops and replaces them with a table lookup.

Current limitations:
- Only works on byte loops
	The CRC width can be any size, but the data input is limited to one byte, i.e. a loop with an iteration count of 8.
- Only works on single-block loops
	Most CRC loops will already have been flattened into a single block with select instructions by this point in the pipeline.

Both limitations were chosen in an effort to reduce complexity, especially for a first patch. The code can be fairly easily extended to overcome these limitations.

Implementation details:
1) Check if the loop looks like CRC and extract some useful information
2) Execute one iteration of the loop's instructions to see what happens to our output value
3) Ensure the output value is predicated on the value of the M/LSB of our input
4) Construct an expected output value of one iteration of CRC using the extracted information from step one and compare
5) Construct a lookup table and replace the output value with a lookup

---

Patch is 46.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79295.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+910-9) 
- (added) llvm/test/Transforms/LoopIdiom/crc/crc.ll (+195) 
- (added) llvm/test/Transforms/LoopIdiom/crc/not-crc.ll (+113) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 3721564890ddb4..f20947daaed8d5 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -90,6 +90,8 @@
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
+#include <map>
+#include <sstream>
 #include <utility>
 #include <vector>
 
@@ -135,6 +137,9 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
              "with -Os/-Oz"),
     cl::init(true), cl::Hidden);
 
+static cl::opt<bool> CRCRecognize("recognize-crc", cl::init(false), cl::Hidden,
+                                  cl::desc("Replace CRC loops with a lookup"));
+
 namespace {
 
 class LoopIdiomRecognize {
@@ -186,6 +191,15 @@ class LoopIdiomRecognize {
             // handling.
   };
 
+  struct CRCInfo {
+    Value *CRCInput = nullptr;  // Incoming CRC; its type is the CRC type.
+    Value *CRCOutput = nullptr; // In-loop result matched against exit PHIs.
+    Value *DataInput = nullptr; // Optional data xor'ed into the index; or null.
+    uint64_t Width = 0;         // NOTE(review): unused in code shown; confirm.
+    uint64_t Polynomial = 0;    // Generator polynomial, non-reversed form.
+    bool BitReversed = false;   // True for the right-shifting (LSB-first) form.
+  };
+
   /// \name Countable Loop Idiom Handling
   /// @{
 
@@ -242,6 +256,9 @@ class LoopIdiomRecognize {
 
   bool recognizeShiftUntilBitTest();
   bool recognizeShiftUntilZero();
+  std::optional<CRCInfo> looksLikeCRC(const SCEV *BECount);
+  bool recognizeCRC(const SCEV *BECount);
+  void writeTableBasedCRCOneByte(CRCInfo &CRC);
 
   /// @}
 };
@@ -298,13 +315,8 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
   ApplyCodeSizeHeuristics =
       L->getHeader()->getParent()->hasOptSize() && UseLIRCodeSizeHeurs;
 
-  HasMemset = TLI->has(LibFunc_memset);
-  HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
-  HasMemcpy = TLI->has(LibFunc_memcpy);
-
-  if (HasMemset || HasMemsetPattern || HasMemcpy)
-    if (SE->hasLoopInvariantBackedgeTakenCount(L))
-      return runOnCountableLoop();
+  if (SE->hasLoopInvariantBackedgeTakenCount(L))
+    return runOnCountableLoop();
 
   return runOnNoncountableLoop();
 }
@@ -329,6 +341,17 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
                     << "] Countable Loop %" << CurLoop->getHeader()->getName()
                     << "\n");
 
+  bool MadeChange = false;
+  if (CRCRecognize)
+    MadeChange |= recognizeCRC(BECount);
+
+  HasMemset = TLI->has(LibFunc_memset);
+  HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
+  HasMemcpy = TLI->has(LibFunc_memcpy);
+
+  if (!(HasMemset || HasMemsetPattern || HasMemcpy))
+    return MadeChange;
+
   // The following transforms hoist stores/memsets into the loop pre-header.
   // Give up if the loop has instructions that may throw.
   SimpleLoopSafetyInfo SafetyInfo;
@@ -336,8 +359,6 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
   if (SafetyInfo.anyBlockMayThrow())
     return false;
 
-  bool MadeChange = false;
-
   // Scan all the blocks in the loop that are not in subloops.
   for (auto *BB : CurLoop->getBlocks()) {
     // Ignore blocks in subloops.
@@ -2868,3 +2889,883 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
   ++NumShiftUntilZero;
   return MadeChange;
 }
+
+// Mirror the low NumBits bits of Num: bit 0 moves to bit NumBits-1, bit 1 to
+// bit NumBits-2, and so on. Bits of Num at or above NumBits are dropped.
+static uint64_t reverseBits(uint64_t Num, unsigned NumBits) {
+  uint64_t Mirrored = 0;
+  // Walk the destination position downwards while consuming Num from its
+  // least-significant end.
+  for (unsigned Dest = NumBits; Dest-- > 0; Num >>= 1)
+    if (Num & 1)
+      Mirrored |= uint64_t(1) << Dest;
+  return Mirrored;
+}
+
+class ValueBits {
+  // This is a representation of a value's bits in terms of references to
+  // other values' bits, or 1/0 if the bit is known. This allows symbolic
+  // execution of bitwise instructions without knowing the exact values.
+  //
+  // Bits are stored least-significant first (Bits[0] is bit 0); print() and
+  // the example below list the most-significant bit first.
+  //
+  // Example:
+  //
+  // LLVM IR Value i8 %x:
+  // [%x[7], %x[6], %x[5], %x[4], %x[3], %x[2], %x[1], %x[0]]
+  //
+  // %shr = lshr i8 %x, 2
+  // [ 0, 0, %x[7], %x[6], %x[5], %x[4], %x[3], %x[2]]
+  //
+  // %shl = shl i8 %shr, 1
+  // [ 0, %x[7], %x[6], %x[5], %x[4], %x[3], %x[2], 0]
+  //
+  // %xor = xor i8 %shl, 0xb
+  // [ 0, %x[7], %x[6], %x[5], %x[4]^1, %x[3], %x[2]^1, 1]
+  //
+  // NOTE(review): ValueBit and ValueBits objects are created with `new` and
+  // nothing in this class frees them; copies share ValueBit pointers.
+public:
+  // One symbolic bit: a known 0/1, a reference to bit N of an IR value, or the
+  // xor of two other symbolic bits.
+  class ValueBit {
+  public:
+    enum BitType { ONE, ZERO, REF, XOR };
+
+  private:
+    BitType _Type;
+    // Referenced value and bit index; only meaningful for REF bits.
+    std::pair<Value *, uint64_t> _BitRef;
+    // Operands; only meaningful for XOR bits.
+    ValueBit *_LHS = nullptr;
+    ValueBit *_RHS = nullptr;
+
+    ValueBit(BitType Type) : _Type(Type) {}
+    ValueBit(BitType Type, std::pair<Value *, uint64_t> BitRef)
+        : _Type(Type), _BitRef(BitRef) {}
+    ValueBit(BitType Type, ValueBit *LHS, ValueBit *RHS)
+        : _Type(Type), _LHS(LHS), _RHS(RHS) {}
+
+  public:
+    static ValueBit *CreateOneBit() { return new ValueBit(BitType::ONE); }
+    static ValueBit *CreateZeroBit() { return new ValueBit(BitType::ZERO); }
+    static ValueBit *CreateRefBit(Value *Ref, uint64_t Offset) {
+      return new ValueBit(BitType::REF, std::make_pair(Ref, Offset));
+    }
+    static ValueBit *CreateXORBit(ValueBit *LHS, ValueBit *RHS) {
+      return new ValueBit(BitType::XOR, LHS, RHS);
+    }
+    inline BitType getType() { return _Type; }
+
+    // Structural equality. XOR bits compare equal with their operands in
+    // either order.
+    bool equals(ValueBit *RHS) {
+      if (_Type != RHS->getType())
+        return false;
+      switch (_Type) {
+      case BitType::ONE:
+      case BitType::ZERO:
+        return true;
+      case BitType::REF:
+        return _BitRef == RHS->_BitRef;
+      case BitType::XOR:
+        return (_LHS->equals(RHS->_LHS) && _RHS->equals(RHS->_RHS)) ||
+               (_LHS->equals(RHS->_RHS) && _RHS->equals(RHS->_LHS));
+      }
+      return false;
+    }
+
+    void print(raw_ostream &OS) {
+      switch (_Type) {
+      case BitType::ONE:
+        OS << "1";
+        break;
+      case BitType::ZERO:
+        OS << "0";
+        break;
+      case BitType::REF:
+        OS << _BitRef.first->getNameOrAsOperand() << "[" << _BitRef.second
+           << "]";
+        break;
+      case BitType::XOR:
+        _LHS->print(OS);
+        OS << "^";
+        _RHS->print(OS);
+        break;
+      }
+    }
+  };
+
+private:
+  // Zero-initialised so that objects built through the protected default
+  // constructor never expose an indeterminate size.
+  uint64_t Size = 0;
+  std::vector<ValueBit *> Bits;
+
+  // In-place shift left: zeros enter at bit 0, the top bits fall off.
+  virtual void _Shl(uint64_t N) {
+    // Shifting by Size or more clears every bit, so larger counts need not
+    // be iterated.
+    for (N = std::min<uint64_t>(N, Size); N > 0; N--) {
+      Bits.insert(Bits.begin(), ValueBit::CreateZeroBit());
+      Bits.pop_back();
+    }
+  }
+  // In-place logical shift right: zeros enter at the top, bit 0 falls off.
+  virtual void _LShr(uint64_t N) {
+    for (N = std::min<uint64_t>(N, Size); N > 0; N--) {
+      Bits.insert(Bits.end(), ValueBit::CreateZeroBit());
+      Bits.erase(Bits.begin());
+    }
+  }
+  // In-place xor with RHS, folding known constants where possible.
+  virtual void _Xor(ValueBits *RHS) {
+    assert(Size == RHS->getSize());
+    for (unsigned I = 0; I < Size; I++) {
+      // Read the current bit *before* replacing it; each slot is overwritten
+      // by assignment so no iterator is used after the vector is modified.
+      ValueBit *LHSBit = Bits[I];
+      ValueBit *RHSBit = RHS->getBit(I);
+      switch (RHSBit->getType()) {
+      case ValueBit::BitType::ZERO:
+        // x ^ 0 == x
+        break;
+      case ValueBit::BitType::ONE:
+        if (LHSBit->getType() == ValueBit::BitType::ZERO)
+          Bits[I] = ValueBit::CreateOneBit();
+        else if (LHSBit->getType() == ValueBit::BitType::ONE)
+          Bits[I] = ValueBit::CreateZeroBit();
+        else
+          Bits[I] = ValueBit::CreateXORBit(LHSBit, ValueBit::CreateOneBit());
+        break;
+      case ValueBit::BitType::REF:
+      case ValueBit::BitType::XOR:
+        if (LHSBit->getType() == ValueBit::BitType::ZERO)
+          Bits[I] = new ValueBit(*RHSBit); // 0 ^ y == y
+        else
+          Bits[I] = ValueBit::CreateXORBit(LHSBit, RHSBit);
+        break;
+      }
+    }
+  }
+  // In-place zero extension to ToSize bits.
+  virtual void _ZExt(uint64_t ToSize) {
+    assert(ToSize > Size);
+    for (uint64_t I = 0; I < ToSize - Size; I++)
+      Bits.push_back(ValueBit::CreateZeroBit());
+    Size = ToSize;
+  }
+  // In-place truncation to the low ToSize bits.
+  virtual void _Trunc(uint64_t ToSize) {
+    assert(ToSize < Size);
+    Bits.erase(Bits.begin() + ToSize, Bits.end());
+    Size = ToSize;
+  }
+  // In-place and with a constant mask: bits whose mask bit is 0 become 0.
+  virtual void _And(uint64_t RHS) {
+    for (unsigned I = 0; I < Size; I++) {
+      if (!(RHS & 1))
+        Bits[I] = ValueBit::CreateZeroBit();
+      RHS >>= 1;
+    }
+  }
+
+protected:
+  ValueBits() {}
+
+public:
+  // The class is polymorphic and handled through base pointers.
+  virtual ~ValueBits() = default;
+
+  // Every bit references the corresponding bit of InitialVal.
+  ValueBits(Value *InitialVal, uint64_t BitLength) : Size(BitLength) {
+    for (unsigned i = 0; i < BitLength; i++)
+      Bits.push_back(ValueBit::CreateRefBit(InitialVal, i));
+  }
+  // Every bit is the known 0/1 taken from the low BitLength bits of
+  // InitialVal.
+  ValueBits(uint64_t InitialVal, uint64_t BitLength) : Size(BitLength) {
+    for (unsigned i = 0; i < BitLength; i++) {
+      if (InitialVal & 0x1)
+        Bits.push_back(ValueBit::CreateOneBit());
+      else
+        Bits.push_back(ValueBit::CreateZeroBit());
+      InitialVal >>= 1;
+    }
+  }
+  uint64_t getSize() { return Size; }
+  ValueBit *getBit(unsigned i) { return Bits[i]; }
+
+  virtual ValueBits *copyBits() { return new ValueBits(*this); }
+
+  // The static helpers below leave their operands untouched and return a
+  // newly allocated result.
+  static ValueBits *Shl(ValueBits *LHS, uint64_t N) {
+    ValueBits *Shifted = LHS->copyBits();
+    Shifted->_Shl(N);
+    return Shifted;
+  }
+  static ValueBits *LShr(ValueBits *LHS, uint64_t N) {
+    ValueBits *Shifted = LHS->copyBits();
+    Shifted->_LShr(N);
+    return Shifted;
+  }
+  static ValueBits *Xor(ValueBits *LHS, ValueBits *RHS) {
+    ValueBits *Xord = LHS->copyBits();
+    Xord->_Xor(RHS);
+    return Xord;
+  }
+  static ValueBits *ZExt(ValueBits *LHS, uint64_t ToSize) {
+    ValueBits *Zexted = LHS->copyBits();
+    Zexted->_ZExt(ToSize);
+    return Zexted;
+  }
+  static ValueBits *Trunc(ValueBits *LHS, uint64_t N) {
+    ValueBits *Trunced = LHS->copyBits();
+    Trunced->_Trunc(N);
+    return Trunced;
+  }
+  static ValueBits *And(ValueBits *LHS, uint64_t RHS) {
+    ValueBits *Anded = LHS->copyBits();
+    Anded->_And(RHS);
+    return Anded;
+  }
+
+  virtual bool isPredicated() { return false; }
+
+  // Bit-by-bit structural equality; values of different sizes are unequal.
+  virtual bool equals(ValueBits *RHS) {
+    if (Size != RHS->getSize())
+      return false;
+
+    for (unsigned I = 0; I < Size; I++)
+      if (!getBit(I)->equals(RHS->getBit(I)))
+        return false;
+
+    return true;
+  }
+
+  // Print as "[msb | ... | lsb]" followed by a newline.
+  virtual void print(raw_ostream &OS) {
+    assert(Size != 0);
+    OS << "[";
+    Bits[Size - 1]->print(OS);
+    for (int i = Size - 2; i >= 0; i--) {
+      OS << " | ";
+      Bits[i]->print(OS);
+    }
+    OS << "]\n";
+  }
+};
+
+// Stream a ValueBits by forwarding to its virtual print(), so derived classes
+// print through their own override.
+inline raw_ostream &operator<<(raw_ostream &OS, ValueBits &VBS) {
+  VBS.print(OS);
+  return OS;
+}
+
+// Stream a single symbolic bit by forwarding to ValueBit::print().
+inline raw_ostream &operator<<(raw_ostream &OS, ValueBits::ValueBit &VB) {
+  VB.print(OS);
+  return OS;
+}
+class PredicatedValueBits : public ValueBits {
+  // This would be representative of select or phi instructions where the bits
+  // would depend on an icmp. The object holds no bits of its own: every
+  // operation is forwarded to both the "true" and the "false" alternative.
+private:
+  ICmpInst *_Predicate;
+  ValueBits *_IfTrue;
+  ValueBits *_IfFalse;
+
+  // Each mutator replaces both alternatives with freshly computed copies, so
+  // the ValueBits shared with other PredicatedValueBits are never modified.
+  void _Shl(uint64_t N) override {
+    _IfTrue = ValueBits::Shl(_IfTrue, N);
+    _IfFalse = ValueBits::Shl(_IfFalse, N);
+  }
+  void _LShr(uint64_t N) override {
+    _IfTrue = ValueBits::LShr(_IfTrue, N);
+    _IfFalse = ValueBits::LShr(_IfFalse, N);
+  }
+  void _ZExt(uint64_t N) override {
+    _IfTrue = ValueBits::ZExt(_IfTrue, N);
+    _IfFalse = ValueBits::ZExt(_IfFalse, N);
+  }
+  void _And(uint64_t N) override {
+    _IfTrue = ValueBits::And(_IfTrue, N);
+    _IfFalse = ValueBits::And(_IfFalse, N);
+  }
+  void _Xor(ValueBits *RHS) override {
+    _IfTrue = ValueBits::Xor(_IfTrue, RHS);
+    _IfFalse = ValueBits::Xor(_IfFalse, RHS);
+  }
+  void _Trunc(uint64_t N) override {
+    _IfTrue = ValueBits::Trunc(_IfTrue, N);
+    _IfFalse = ValueBits::Trunc(_IfFalse, N);
+  }
+
+public:
+  PredicatedValueBits(ICmpInst *Predicate, ValueBits *IfTrue,
+                      ValueBits *IfFalse)
+      : _Predicate(Predicate), _IfTrue(IfTrue), _IfFalse(IfFalse) {}
+
+  ValueBits *copyBits() override { return new PredicatedValueBits(*this); }
+  bool isPredicated() override { return true; }
+  ValueBits *getIfTrue() { return _IfTrue; }
+  ValueBits *getIfFalse() { return _IfFalse; }
+  ICmpInst *getPredicate() { return _Predicate; }
+
+  // The base implementation compares the base class's own size and bit
+  // vector, which a predicated value never populates. Compare the predicate
+  // and both alternatives instead; a non-predicated RHS is never equal.
+  bool equals(ValueBits *RHS) override {
+    if (!RHS->isPredicated())
+      return false;
+    auto *Other = static_cast<PredicatedValueBits *>(RHS);
+    return _Predicate == Other->_Predicate &&
+           _IfTrue->equals(Other->_IfTrue) &&
+           _IfFalse->equals(Other->_IfFalse);
+  }
+
+  void print(raw_ostream &OS) override {
+    OS << "Predicate: " << *_Predicate << "\nIf True:\n"
+       << *_IfTrue << "If False:\n"
+       << *_IfFalse;
+  }
+};
+
+// Execute the instructions in a basic block whilst mapping out Values to
+// ValueBits
+// Execute the instructions in a basic block whilst mapping out Values to
+// ValueBits.
+//
+// Instructions are visited in block order. For each supported opcode (phi,
+// shl, lshr, and, xor, zext, trunc, select) the resulting ValueBits is stored
+// in ValueMap under the instruction. Unsupported opcodes are skipped; their
+// users then see them as opaque values.
+//
+// Returns false when an instruction is found that cannot be modelled (a
+// non-constant shift amount or mask, a constant mask wider than 64 bits, a
+// phi with several incoming blocks other than BB, or a select whose condition
+// is not an icmp). ValueMap may have been partially filled in that case.
+static bool symbolicallyExecute(BasicBlock *BB,
+                                std::map<Value *, ValueBits *> &ValueMap) {
+
+  // The operand index is `unsigned` rather than `uint8_t` so that it is
+  // printed as a number, not as a character, in the debug output.
+  auto getConstantOperand = [](Instruction *I, unsigned Operand) {
+    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(Operand));
+    if (!CI) {
+      LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                        << " handle this operation with non-constant operand "
+                        << Operand << ":\n"
+                        << *I << "\n");
+    }
+    return CI;
+  };
+
+  // Look up the bits already computed for Val, or build them: known bits for
+  // an integer constant, otherwise references to Val's own bits.
+  auto getOrCreateValueBits = [&ValueMap](Value *Val) -> ValueBits * {
+    auto Known = ValueMap.find(Val);
+    if (Known != ValueMap.end())
+      return Known->second;
+    uint64_t NumBits = Val->getType()->getScalarSizeInBits();
+    auto *CI = dyn_cast<ConstantInt>(Val);
+    // Constants wider than 64 bits cannot be passed as a uint64_t; treating
+    // them as opaque values is conservative.
+    if (CI && CI->getBitWidth() <= 64)
+      return new ValueBits(CI->getZExtValue(), NumBits);
+    return new ValueBits(Val, NumBits);
+  };
+
+  for (Instruction &I : *BB) {
+    uint64_t BitSize = I.getType()->getScalarSizeInBits();
+    switch (I.getOpcode()) {
+    case Instruction::PHI: {
+      PHINode *PHI = cast<PHINode>(&I);
+      // Only the value entering from outside BB is modelled, i.e. the state
+      // at the start of the first iteration.
+      const BasicBlock *IncomingBlock = nullptr;
+      for (const BasicBlock *Incoming : PHI->blocks()) {
+        if (Incoming != BB) {
+          if (IncomingBlock) {
+            LLVM_DEBUG(dbgs()
+                       << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                       << " handle loop with multiple entries" << I << "\n");
+            return false;
+          }
+          IncomingBlock = Incoming;
+        }
+      }
+      // Bail out instead of asserting so release builds never look up a null
+      // block.
+      if (!IncomingBlock)
+        return false;
+      ValueMap[&I] =
+          getOrCreateValueBits(PHI->getIncomingValueForBlock(IncomingBlock));
+    } break;
+    case Instruction::Shl: {
+      ConstantInt *CI = getConstantOperand(&I, 1);
+      if (!CI)
+        return false;
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      // Shift amounts are unsigned. Clamp to the bit width: a larger shift
+      // yields poison in IR, and an all-zero result keeps the work bounded.
+      ValueMap[&I] = ValueBits::Shl(LHSBits, CI->getLimitedValue(BitSize));
+    } break;
+    case Instruction::LShr: {
+      ConstantInt *CI = getConstantOperand(&I, 1);
+      if (!CI)
+        return false;
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::LShr(LHSBits, CI->getLimitedValue(BitSize));
+    } break;
+    case Instruction::And: {
+      ConstantInt *CI = getConstantOperand(&I, 1);
+      if (!CI)
+        return false;
+      // The mask is handed over as a uint64_t, so wider masks cannot be
+      // represented.
+      if (CI->getBitWidth() > 64)
+        return false;
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::And(LHSBits, CI->getZExtValue());
+    } break;
+    case Instruction::Xor: {
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueBits *RHSBits = getOrCreateValueBits(I.getOperand(1));
+      ValueMap[&I] = ValueBits::Xor(LHSBits, RHSBits);
+    } break;
+    case Instruction::ZExt: {
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::ZExt(LHSBits, BitSize);
+    } break;
+    case Instruction::Trunc: {
+      ValueBits *LHSBits = getOrCreateValueBits(I.getOperand(0));
+      ValueMap[&I] = ValueBits::Trunc(LHSBits, BitSize);
+    } break;
+    case Instruction::Select: {
+      SelectInst *Select = cast<SelectInst>(&I);
+      ICmpInst *Cond = dyn_cast<ICmpInst>(Select->getCondition());
+      if (!Cond) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                          << " handle SelectInst with non-icmp condition: " << I
+                          << "\n");
+        return false;
+      }
+      ValueBits *IfTrue = getOrCreateValueBits(Select->getTrueValue());
+      ValueBits *IfFalse = getOrCreateValueBits(Select->getFalseValue());
+      ValueMap[&I] = new PredicatedValueBits(Cond, IfTrue, IfFalse);
+    } break;
+    default:
+      // If this instruction is not recognized, then just continue. This is
+      // okay because users of this will just reference it by value, which is
+      // conservative.
+      break;
+    }
+  }
+  return true;
+}
+
+// Replace the result of a recognized one-byte CRC loop with a table lookup.
+//
+// A 256-entry constant table for CRC.Polynomial is added to the module as a
+// private global. In the loop's exit block the index byte is formed from
+// CRC.CRCInput (xor'ed with CRC.DataInput when present), the table entry is
+// loaded and combined with the remaining bits of CRC.CRCInput, and every
+// single-entry exit PHI fed by CRC.CRCOutput is redirected to the new value.
+// The loop itself is left in place.
+//
+// NOTE(review): assumes the caller has checked that the loop has a unique
+// exit block.
+void LoopIdiomRecognize::writeTableBasedCRCOneByte(CRCInfo &CRC) {
+  BasicBlock *ExitBB = CurLoop->getExitBlock();
+  IRBuilder<> Builder(ExitBB);
+  Builder.SetInsertPoint(ExitBB->getFirstNonPHI());
+  Type *CRCType = CRC.CRCInput->getType();
+  uint64_t CRCSize = CRCType->getScalarSizeInBits();
+  // The table is computed in 64-bit arithmetic and indexed by a whole byte.
+  // NOTE(review): presumably guaranteed by looksLikeCRC -- confirm.
+  assert(CRCSize >= 8 && CRCSize <= 64 && "unsupported CRC width");
+
+  // All shifts are done on a 64-bit one; a plain `0x1` is an int and cannot
+  // be shifted by 31 or more.
+  const uint64_t One = 1;
+  const uint64_t CRCMask = CRCSize == 64 ? ~uint64_t(0) : (One << CRCSize) - 1;
+  uint64_t Polynomial = CRC.Polynomial;
+  // The bit that decides whether the polynomial is xor'ed in: the LSB when
+  // shifting right (bit-reversed form), otherwise the MSB.
+  uint64_t SB = CRC.BitReversed ? One : (One << (CRCSize - 1));
+  if (CRC.BitReversed)
+    Polynomial = reverseBits(Polynomial, CRCSize);
+
+  // Construct the CRC table. Each entry is materialized directly as a
+  // constant of the CRC type, which avoids depending on host byte order or
+  // on the CRC width being a whole number of bytes.
+  std::vector<Constant *> TableElts;
+  TableElts.reserve(256);
+  for (uint64_t Dividend = 0; Dividend < 256; Dividend++) {
+    uint64_t CurByte = Dividend;
+    if (!CRC.BitReversed)
+      CurByte <<= CRCSize - 8;
+    for (unsigned Bit = 0; Bit < 8; Bit++) {
+      bool Divide = (CurByte & SB) != 0;
+      CurByte = CRC.BitReversed ? CurByte >> 1 : CurByte << 1;
+      if (Divide)
+        CurByte ^= Polynomial;
+    }
+    // Drop the bits shifted out above the CRC width.
+    TableElts.push_back(ConstantInt::get(CRCType, CurByte & CRCMask));
+  }
+
+  // Construct and add the table as a global variable
+  ArrayType *TableType = ArrayType::get(CRCType, 256);
+  Constant *ConstantArr = ConstantArray::get(TableType, TableElts);
+  std::stringstream TableNameSS;
+  TableNameSS << "crctable.i" << CRCSize << "." << CRC.Polynomial;
+  if (CRC.BitReversed)
+    TableNameSS << ".reversed";
+  GlobalVariable *CRCTableGlobal = new GlobalVariable(
+      TableType, true, GlobalVariable::LinkageTypes::PrivateLinkage,
+      ConstantArr, TableNameSS.str());
+  ExitBB->getModule()->insertGlobalVariable(CRCTableGlobal);
+
+  // Construct the IR to load from this table
+  Value *CRCOffset = CRC.CRCInput;
+  if (CRCSize > 8) {
+    // Get the next byte into position and truncate
+    if (!CRC.BitReversed)
+      CRCOffset = Builder.CreateLShr(CRCOffset, CRCSize - 8);
+    CRCOffset = Builder.CreateTrunc(CRCOffset, Builder.getInt8Ty());
+  }
+  if (CRC.DataInput) {
+    // Data size can be more than 8 due to extending
+    Value *Data = CRC.DataInput;
+    if (CRC.DataInput->getType()->getScalarSizeInBits() > 8) {
+      Data = Builder.CreateTrunc(Data, Builder.getInt8Ty());
+    }
+    // Xor the data, offset into the table and load
+    CRCOffset = Builder.CreateXor(CRCOffset, Data);
+  }
+
+  CRCOffset = Builder.CreateZExt(CRCOffset, Builder.getInt32Ty());
+  Value *Gep = Builder.CreateInBoundsGEP(CRCType, CRCTableGlobal, {CRCOffset});
+  Value *CRCRes = Builder.CreateLoad(CRCType, Gep);
+  if (CRCSize > 8) {
+    // Shift out SB used for division and Xor the rest of the crc back in
+    Value *RestOfCRC = CRC.CRCInput;
+    if (CRC.BitReversed)
+      RestOfCRC = Builder.CreateLShr(CRC.CRCInput, 8);
+    else
+      RestOfCRC = Builder.CreateShl(CRC.CRCInput, 8);
+    CRCRes = Builder.CreateXor(RestOfCRC, CRCRes);
+  }
+  for (PHINode &ExitPhi : ExitBB->phis()) {
+    if (ExitPhi.getNumIncomingValues() == 1 &&
+        ExitPhi.getIncomingValue(0) == CRC.CRCOutput)
+      ExitPhi.replaceAllUsesWith(CRCRes);
+  }
+}
+
+bool LoopIdiomRecognize::recognizeCRC(const SCEV *BECount) {
+  // Step one: Check if the loop looks like crc, and extract some useful
+  // information for us to check
+  std::optional<CRCInfo> MaybeCRC = looksLikeCRC(BECount);
+  if (!MaybeCRC)
+    return false;
+  CRCInfo CRC = *MaybeCRC;
+
+  uint64_t CRCSize = CRC.CRCInput->getType()->getScalarSizeInBits();
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Found potential CRCLoop "
+                    << *CurLoop << "\n"
+                    << "Input CRC: " << *CRC.CRCInput << "\n"
+                    << "Output CRC: " << *CRC.CRCOutput << "\n"
+           ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79295


More information about the llvm-commits mailing list