[llvm] [LoopIdiomRecognizer] Implement CRC recognition (PR #79295)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 19:58:43 PST 2024


================
@@ -2868,3 +2889,888 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
   ++NumShiftUntilZero;
   return MadeChange;
 }
+
// Return the low NumBits bits of Num in reversed order: bit 0 of Num becomes
// bit NumBits-1 of the result, bit 1 becomes bit NumBits-2, and so on. Bits of
// Num at positions NumBits and above are ignored; NumBits == 0 yields 0.
static uint64_t reverseBits(uint64_t Num, unsigned NumBits) {
  // The result is a uint64_t, so at most 64 bits can be reversed. (The
  // previous formulation shifted by NumBits - i, which is undefined behavior
  // once that amount reaches 64.)
  assert(NumBits <= 64 && "cannot reverse more than 64 bits");
  uint64_t Reversed = 0;
  for (unsigned I = 0; I < NumBits; ++I) {
    // Append Num's current lowest bit as the new least-significant bit of the
    // result; earlier bits move towards the top.
    Reversed = (Reversed << 1) | (Num & 1);
    Num >>= 1;
  }
  return Reversed;
}
+
+class ValueBits {
+  // This is a representation of a value's bits in terms of references to
+  // other values' bits, or 1/0 if the bit is known. This allows symbolic
+  // execution of bitwise instructions without knowing the exact values.
+  //
+  // Example:
+  //
+  // LLVM IR Value i8 %x:
+  // [%x[7], %x[6], %x[5], %x[4], %x[3], %x[2], %x[1], %x[0]]
+  //
+  // %shr = lshr i8 %x, 2
+  // [ 0, 0, %x[7], %x[6], %x[5], %x[4], %x[3], %x[2]]
+  //
+  // %shl = shl i8 %shr, 1
+  // [ 0, %x[7], %x[6], %x[5], %x[4], %x[3], %x[2], 0]
+  //
+  // %xor = xor i8 %shl, 0xb
+  // [ 0, %x[7], %x[6], %x[5], %x[4]^1, %x[3], %x[2]^1, 1]
+public:
+  class ValueBit {
+  public:
+    enum BitType { ONE, ZERO, REF, XOR };
+
+  private:
+    BitType _Type;
+    std::pair<Value *, uint64_t> _BitRef;
+    ValueBit *_LHS;
+    ValueBit *_RHS;
+
+    ValueBit(BitType Type) : _Type(Type) {}
+    ValueBit(BitType Type, std::pair<Value *, uint64_t> BitRef)
+        : _Type(Type), _BitRef(BitRef) {}
+    ValueBit(BitType Type, ValueBit *LHS, ValueBit *RHS)
+        : _Type(Type), _LHS(LHS), _RHS(RHS) {}
+
+  public:
+    static ValueBit *CreateOneBit() { return new ValueBit(BitType::ONE); }
+    static ValueBit *CreateZeroBit() { return new ValueBit(BitType::ZERO); }
+    static ValueBit *CreateRefBit(Value *Ref, uint64_t Offset) {
+      return new ValueBit(BitType::REF, std::make_pair(Ref, Offset));
+    }
+    static ValueBit *CreateXORBit(ValueBit *LHS, ValueBit *RHS) {
+      return new ValueBit(BitType::XOR, LHS, RHS);
+    }
+    inline BitType getType() { return _Type; }
+    bool equals(ValueBit *RHS) {
+      if (_Type != RHS->getType())
+        return false;
+      switch (_Type) {
+      case BitType::ONE:
+      case BitType::ZERO:
+        return true;
+      case BitType::REF:
+        return _BitRef == RHS->_BitRef;
+      case BitType::XOR:
+        return (_LHS->equals(RHS->_LHS) && _RHS->equals(RHS->_RHS)) ||
+               (_LHS->equals(RHS->_RHS) && _RHS->equals(RHS->_LHS));
+      }
+      return false;
+    }
+
+    void print(raw_ostream &OS) {
+      switch (_Type) {
+      case BitType::ONE:
+        OS << "1";
+        break;
+      case BitType::ZERO:
+        OS << "0";
+        break;
+      case BitType::REF:
+        OS << _BitRef.first->getNameOrAsOperand() << "[" << _BitRef.second
+           << "]";
+        break;
+      case BitType::XOR:
+        _LHS->print(OS);
+        OS << "^";
+        _RHS->print(OS);
+        break;
+      }
+    }
+  };
+
+private:
+  uint64_t Size;
+  std::vector<ValueBit *> Bits;
+
+  virtual void _Shl(uint64_t N) {
+    for (; N > 0; N--) {
+      Bits.insert(Bits.begin(), ValueBit::CreateZeroBit());
+      Bits.pop_back();
+    }
+  }
+  virtual void _LShr(uint64_t N) {
+    for (; N > 0; N--) {
+      Bits.insert(Bits.end(), ValueBit::CreateZeroBit());
+      Bits.erase(Bits.begin());
+    }
+  }
+  virtual void _Xor(ValueBits *RHS) {
+    assert(Size == RHS->getSize());
+    for (unsigned I = 0; I < Size; I++) {
+      auto It = Bits.begin() + I;
+      ValueBit *RHSBit = RHS->getBit(I);
+      if (RHSBit->getType() == ValueBit::BitType::ONE) {
+        Bits.erase(It);
+        if ((*It)->getType() == ValueBit::BitType::ZERO) {
+          Bits.insert(It, ValueBit::CreateOneBit());
+        } else if ((*It)->getType() == ValueBit::BitType::ONE) {
+          Bits.insert(It, ValueBit::CreateZeroBit());
+        } else {
+          ValueBit *One = ValueBit::CreateOneBit();
+          Bits.insert(It, ValueBit::CreateXORBit(*It, One));
+        }
+      } else if (RHSBit->getType() != ValueBit::BitType::ZERO) {
+        if ((*It)->getType() == ValueBit::BitType::ZERO) {
+          Bits.erase(It);
+          ValueBit *BitRef = new ValueBit(*RHSBit);
+          Bits.insert(It, BitRef);
+        } else {
+          ValueBit *ItVB = *It;
+          Bits.erase(It);
+          Bits.insert(It, ValueBit::CreateXORBit(ItVB, RHSBit));
+        }
+      }
+    }
+  }
+  virtual void _ZExt(uint64_t ToSize) {
+    assert(ToSize > Size);
+    for (uint64_t I = 0; I < ToSize - Size; I++)
+      Bits.push_back(ValueBit::CreateZeroBit());
+    Size = ToSize;
+  }
+  virtual void _Trunc(uint64_t ToSize) {
+    assert(ToSize < Size);
+    Bits.erase(Bits.begin() + ToSize, Bits.end());
+    Size = ToSize;
+  }
+  virtual void _And(uint64_t RHS) {
+    for (unsigned I = 0; I < Size; I++) {
+      if (!(RHS & 1)) {
+        auto It = Bits.begin() + I;
+        Bits.erase(It);
+        Bits.insert(It, ValueBit::CreateZeroBit());
+      }
+      RHS >>= 1;
+    }
+  }
+
+protected:
+  ValueBits() {}
+
+public:
+  ValueBits(Value *InitialVal, uint64_t BitLength) : Size(BitLength) {
+    for (unsigned i = 0; i < BitLength; i++)
+      Bits.push_back(ValueBit::CreateRefBit(InitialVal, i));
+  }
+  ValueBits(uint64_t InitialVal, uint64_t BitLength) : Size(BitLength) {
+    for (unsigned i = 0; i < BitLength; i++) {
+      if (InitialVal & 0x1)
+        Bits.push_back(ValueBit::CreateOneBit());
+      else
+        Bits.push_back(ValueBit::CreateZeroBit());
+      InitialVal >>= 1;
+    }
+  }
+  uint64_t getSize() { return Size; }
+  ValueBit *getBit(unsigned i) { return Bits[i]; }
+
+  virtual ValueBits *copyBits() { return new ValueBits(*this); }
+
+  static ValueBits *Shl(ValueBits *LHS, uint64_t N) {
+    ValueBits *Shifted = LHS->copyBits();
+    Shifted->_Shl(N);
+    return Shifted;
+  }
+  static ValueBits *LShr(ValueBits *LHS, uint64_t N) {
+    ValueBits *Shifted = LHS->copyBits();
+    Shifted->_LShr(N);
+    return Shifted;
+  }
+  static ValueBits *Xor(ValueBits *LHS, ValueBits *RHS) {
+    ValueBits *Xord = LHS->copyBits();
+    Xord->_Xor(RHS);
+    return Xord;
+  }
+  static ValueBits *ZExt(ValueBits *LHS, uint64_t ToSize) {
+    ValueBits *Zexted = LHS->copyBits();
+    Zexted->_ZExt(ToSize);
+    return Zexted;
+  }
+  static ValueBits *Trunc(ValueBits *LHS, uint64_t N) {
+    ValueBits *Trunced = LHS->copyBits();
+    Trunced->_Trunc(N);
+    return Trunced;
+  }
+  static ValueBits *And(ValueBits *LHS, uint64_t RHS) {
+    ValueBits *Anded = LHS->copyBits();
+    Anded->_And(RHS);
+    return Anded;
+  }
+
+  virtual bool isPredicated() { return false; }
+
+  virtual bool equals(ValueBits *RHS) {
+    if (Size != RHS->getSize())
+      return false;
+
+    for (unsigned I = 0; I < Size; I++)
+      if (!getBit(I)->equals(RHS->getBit(I)))
+        return false;
+
+    return true;
+  }
+
+  virtual void print(raw_ostream &OS) {
+    assert(Size != 0);
+    OS << "[";
+    Bits[Size - 1]->print(OS);
+    for (int i = Size - 2; i >= 0; i--) {
+      OS << " | ";
+      Bits[i]->print(OS);
+    }
+    OS << "]\n";
+  }
+};
+
+inline raw_ostream &operator<<(raw_ostream &OS, ValueBits &VBS) {
+  VBS.print(OS);
+  return OS;
+}
+
+inline raw_ostream &operator<<(raw_ostream &OS, ValueBits::ValueBit &VB) {
+  VB.print(OS);
+  return OS;
+}
+class PredicatedValueBits : public ValueBits {
+  // This would be representitive of select or phi instructions where the bits
+  // would depend on an icmp.
+private:
+  ICmpInst *_Predicate;
+  ValueBits *_IfTrue;
+  ValueBits *_IfFalse;
+
+  void _Shl(uint64_t N) override {
+    _IfTrue = ValueBits::Shl(_IfTrue, N);
+    _IfFalse = ValueBits::Shl(_IfFalse, N);
+  }
+  void _LShr(uint64_t N) override {
+    _IfTrue = ValueBits::LShr(_IfTrue, N);
+    _IfFalse = ValueBits::LShr(_IfFalse, N);
+  }
+  void _ZExt(uint64_t N) override {
+    _IfTrue = ValueBits::ZExt(_IfTrue, N);
+    _IfFalse = ValueBits::ZExt(_IfFalse, N);
+  }
+  void _And(uint64_t N) override {
+    _IfTrue = ValueBits::And(_IfTrue, N);
+    _IfFalse = ValueBits::And(_IfFalse, N);
+  }
+  void _Xor(ValueBits *RHS) override {
+    _IfTrue = ValueBits::Xor(_IfTrue, RHS);
+    _IfFalse = ValueBits::Xor(_IfFalse, RHS);
+  }
+  void _Trunc(uint64_t N) override {
+    _IfTrue = ValueBits::Trunc(_IfTrue, N);
+    _IfFalse = ValueBits::Trunc(_IfFalse, N);
+  }
+
+public:
+  PredicatedValueBits(ICmpInst *Predicate, ValueBits *IfTrue,
+                      ValueBits *IfFalse)
+      : _Predicate(Predicate), _IfTrue(IfTrue), _IfFalse(IfFalse) {}
+
+  ValueBits *copyBits() override { return new PredicatedValueBits(*this); }
+  bool isPredicated() override { return true; }
+  ValueBits *getIfTrue() { return _IfTrue; }
+  ValueBits *getIfFalse() { return _IfFalse; }
+  ICmpInst *getPredicate() { return _Predicate; }
+
+  virtual void print(raw_ostream &OS) override {
+    OS << "Predicate: " << *_Predicate << "\nIf True:\n"
+       << *_IfTrue << "If False:\n"
+       << *_IfFalse;
+  }
+};
+
+// Execute the instructions in a basic block whilst mapping out Values to
+// ValueBits
+static bool symbolicallyExecute(BasicBlock *BB,
+                                std::map<Value *, ValueBits *> &ValueMap) {
+
+  auto getConstantOperand = [](Instruction *I, uint8_t Operand) {
+    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(Operand));
+    if (!CI) {
+      LLVM_DEBUG(dbgs() << DEBUG_TYPE " CRCRegonize: Do not know how to"
+                        << " handle this operation with non-constant operand "
+                        << Operand << ":\n"
+                        << *I << "\n");
+    }
+    return CI;
+  };
+
+  auto getOrCreateValueBits = [&ValueMap](Value *Val) {
+    auto Result = ValueMap.find(Val);
+    ValueBits *LHSBits = nullptr;
+    if (Result == ValueMap.end()) {
+      ConstantInt *CI = dyn_cast<ConstantInt>(Val);
+      if (CI) {
+        LHSBits = new ValueBits(CI->getSExtValue(),
+                                Val->getType()->getScalarSizeInBits());
+      } else {
+        LHSBits = new ValueBits(Val, Val->getType()->getScalarSizeInBits());
+      }
+    } else
+      LHSBits = Result->second;
+    return LHSBits;
+  };
+
+  for (Instruction &I : *BB) {
+    uint64_t BitSize = I.getType()->getScalarSizeInBits();
+    switch (I.getOpcode()) {
+    case Instruction::PHI: {
+      PHINode *PHI = dyn_cast<PHINode>(&I);
+      const BasicBlock *IncomingBlock = nullptr;
+      for (const BasicBlock *Incoming : PHI->blocks()) {
+        if (Incoming != BB) {
+          if (IncomingBlock) {
+            LLVM_DEBUG(dbgs()
+                       << DEBUG_TYPE " CRCRegonize: Do not know how to"
----------------
topperc wrote:

Recognize*: "CRCRegonize" should be spelled "CRCRecognize". This spelling error is repeated in many messages.

https://github.com/llvm/llvm-project/pull/79295


More information about the llvm-commits mailing list