[llvm] [SLP] Make getSameOpcode support interchangeable instructions. (PR #127450)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 24 06:40:33 PST 2025


================
@@ -810,6 +810,267 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
 
 namespace {
 
+/// Base class for representing instructions that can be interchanged with other
+/// equivalent forms. For example, multiplication by a power of 2 can be
+/// interchanged with a left shift.
+///
+/// Derived classes implement specific interchange patterns by overriding the
+/// virtual methods to define their interchange logic.
+///
+/// The class keeps a pointer to the main instruction (MainOp) and provides
+/// methods to:
+/// - Check if another instruction is interchangeable (isSame)
+/// - Get the opcode for the interchangeable form (getOpcode)
+/// - Get MainOp's operands rewritten for the opcode of a given instruction
+///   (getOperand)
+///
+/// This base implementation only accepts instructions whose opcode is
+/// identical to MainOp's, so getOperand returns MainOp's operands unchanged.
+class InterchangeableInstruction {
+protected:
+  Instruction *const MainOp;
+
+public:
+  explicit InterchangeableInstruction(Instruction *MainOp) : MainOp(MainOp) {}
+  /// Return true if \p I has exactly the same opcode as MainOp.
+  virtual bool isSame(Instruction *I) {
+    return MainOp->getOpcode() == I->getOpcode();
+  }
+  /// Return the opcode shared by all accepted instructions (MainOp's opcode).
+  virtual unsigned getOpcode() { return MainOp->getOpcode(); }
+  /// Return MainOp's operands expressed for the opcode of \p I. Only the
+  /// identical-opcode case is supported here.
+  virtual SmallVector<Value *> getOperand(Instruction *I) {
+    assert(MainOp->getOpcode() == I->getOpcode() &&
+           "Cannot convert the instruction.");
+    return SmallVector<Value *>(MainOp->operands());
+  }
+  virtual ~InterchangeableInstruction() = default;
+};
+
+class InterchangeableBinOp final : public InterchangeableInstruction {
+  using MaskType = std::uint_fast8_t;
+  // Must stay sorted by opcode value: it is queried with binary_search.
+  constexpr static std::initializer_list<unsigned> SupportedOp = {
+      Instruction::Add,  Instruction::Sub, Instruction::Mul, Instruction::Shl,
+      Instruction::AShr, Instruction::And, Instruction::Or,  Instruction::Xor};
+  // One bit per opcode, from high to low bit: Xor Or And Sub Add Mul AShr Shl
+  // (see opcodeToMask). Mask keeps the opcodes that every instruction accepted
+  // by isSame so far can be rewritten into; SeenBefore records the opcodes of
+  // the instructions that were passed to isSame.
+  MaskType Mask = 0b11111111;
+  MaskType SeenBefore = 0;
+
+  /// If \p I has a ConstantInt operand (or a ConstantDataVector splat of one),
+  /// return that constant together with its operand index. The returned
+  /// constant is nullptr when no such operand is found. For Sub, Shl and AShr
+  /// only a constant second operand is recognized.
+  static std::pair<ConstantInt *, unsigned>
+  isBinOpWithConstantInt(Instruction *I) {
+    unsigned Opcode = I->getOpcode();
+    unsigned Pos = 1;
+    Constant *C;
+    if (!match(I, m_BinOp(m_Value(), m_Constant(C)))) {
+      if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
+          Opcode == Instruction::AShr)
+        return std::make_pair(nullptr, Pos);
+      if (!match(I, m_BinOp(m_Constant(C), m_Value())))
+        return std::make_pair(nullptr, Pos);
+      Pos = 0;
+    }
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      return std::make_pair(CI, Pos);
+    if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
+      if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
+        return std::make_pair(CI, Pos);
+    }
+    return std::make_pair(nullptr, Pos);
+  }
+
+  static MaskType opcodeToMask(unsigned Opcode) {
+    switch (Opcode) {
+    case Instruction::Shl:
+      return 0b1;
+    case Instruction::AShr:
+      return 0b10;
+    case Instruction::Mul:
+      return 0b100;
+    case Instruction::Add:
+      return 0b1000;
+    case Instruction::Sub:
+      return 0b10000;
+    case Instruction::And:
+      return 0b100000;
+    case Instruction::Or:
+      return 0b1000000;
+    case Instruction::Xor:
+      return 0b10000000;
+    }
+    llvm_unreachable("Unsupported opcode.");
+  }
+
+  /// Return the constant C for which `X op C == X` holds for every X when op
+  /// is \p Opcode: 1 for Mul, all-ones for And and 0 for the other supported
+  /// opcodes (Add, Sub, Shl, AShr, Or, Xor).
+  static APInt getIdentityValue(unsigned Opcode, unsigned BitWidth) {
+    switch (Opcode) {
+    case Instruction::Mul:
+      return APInt::getOneBitSet(BitWidth, 0);
+    case Instruction::And:
+      return APInt::getAllOnes(BitWidth);
+    default:
+      return APInt::getZero(BitWidth);
+    }
+  }
+
+  /// Narrow Mask to the opcodes in \p X. Returns false, leaving Mask
+  /// untouched, if that would leave no opcode at all.
+  bool tryAnd(MaskType X) {
+    if (Mask & X) {
+      Mask &= X;
+      return true;
+    }
+    return false;
+  }
+
+public:
+  using InterchangeableInstruction::InterchangeableInstruction;
+  /// Return true if \p I can be rewritten into at least one opcode that all
+  /// previously accepted instructions can also be rewritten into. An operator
+  /// whose constant operand is the identity value (x + 0, x * 1, x & -1,
+  /// x << 0, ...) does not restrict the choice.
+  bool isSame(Instruction *I) override {
+    unsigned Opcode = I->getOpcode();
+    if (!binary_search(SupportedOp, Opcode))
+      return false;
+    SeenBefore |= opcodeToMask(Opcode);
+    ConstantInt *CI = isBinOpWithConstantInt(I).first;
+    if (CI) {
+      const APInt &CIValue = CI->getValue();
+      switch (Opcode) {
+      case Instruction::Shl:
+        if (CIValue.isZero())
+          return true;
+        // x << C equals x * (1 << C) only when C < BitWidth. Larger amounts
+        // yield poison and have no Mul equivalent; accepting them would make
+        // getOperand build an out-of-range APInt::getOneBitSet.
+        if (CIValue.ult(CIValue.getBitWidth()))
+          return tryAnd(0b101);
+        break;
+      case Instruction::Mul:
+        if (CIValue.isOne())
+          return true;
+        if (CIValue.isPowerOf2())
+          return tryAnd(0b101);
+        break;
+      case Instruction::And:
+        if (CIValue.isAllOnes())
+          return true;
+        break;
+      default:
+        if (CIValue.isZero())
+          return true;
+        break;
+      }
+    }
+    return tryAnd(opcodeToMask(Opcode));
+  }
+  /// Return the lowest-bit opcode that is still allowed by Mask and was seen
+  /// among the instructions passed to isSame.
+  unsigned getOpcode() override {
+    MaskType Candidate = Mask & SeenBefore;
+    if (Candidate & 0b1)
+      return Instruction::Shl;
+    if (Candidate & 0b10)
+      return Instruction::AShr;
+    if (Candidate & 0b100)
+      return Instruction::Mul;
+    if (Candidate & 0b1000)
+      return Instruction::Add;
+    if (Candidate & 0b10000)
+      return Instruction::Sub;
+    if (Candidate & 0b100000)
+      return Instruction::And;
+    if (Candidate & 0b1000000)
+      return Instruction::Or;
+    if (Candidate & 0b10000000)
+      return Instruction::Xor;
+    llvm_unreachable("Cannot find interchangeable instruction.");
+  }
+  /// Return MainOp's operands rewritten so that MainOp is expressed with the
+  /// opcode of \p I.
+  SmallVector<Value *> getOperand(Instruction *I) override {
+    unsigned ToOpcode = I->getOpcode();
+    assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
+    unsigned FromOpcode = MainOp->getOpcode();
+    if (FromOpcode == ToOpcode)
+      return SmallVector<Value *>(MainOp->operands());
+    auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
+    // A different opcode is only reachable through a constant operand (see
+    // isSame); without one the mask pins the original opcode.
+    assert(CI && "Cannot convert the instruction.");
+    const APInt &FromCIValue = CI->getValue();
+    unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
+    APInt ToCIValue;
+    switch (FromOpcode) {
+    case Instruction::Shl:
+      if (ToOpcode == Instruction::Mul) {
+        assert(FromCIValue.ult(FromCIValueBitWidth) &&
+               "Cannot convert the instruction.");
+        ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
+                                        FromCIValue.getZExtValue());
+      } else {
+        assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+        ToCIValue = getIdentityValue(ToOpcode, FromCIValueBitWidth);
+      }
+      break;
+    case Instruction::Mul:
+      assert(FromCIValue.isPowerOf2() && "Cannot convert the instruction.");
+      if (ToOpcode == Instruction::Shl) {
+        ToCIValue = APInt(FromCIValueBitWidth, FromCIValue.logBase2());
+      } else {
+        assert(FromCIValue.isOne() && "Cannot convert the instruction.");
+        ToCIValue = getIdentityValue(ToOpcode, FromCIValueBitWidth);
+      }
+      break;
+    case Instruction::And:
+      assert(FromCIValue.isAllOnes() && "Cannot convert the instruction.");
+      ToCIValue = getIdentityValue(ToOpcode, FromCIValueBitWidth);
+      break;
+    default:
+      // Add/Sub/AShr/Or/Xor by 0. The target may be Mul or And, whose identity
+      // is not 0, so a plain zero constant would be a miscompile.
+      assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+      ToCIValue = getIdentityValue(ToOpcode, FromCIValueBitWidth);
+      break;
+    }
+    Value *NonConstOp = MainOp->getOperand(1 - Pos);
+    Value *NewConstOp =
+        ConstantInt::get(MainOp->getOperand(Pos)->getType(), ToCIValue);
+    // Sub, Shl and AShr only have a right identity, so the constant must be
+    // the second operand for them even if MainOp had it on the left.
+    bool ConstMustBeRHS = ToOpcode == Instruction::Sub ||
+                          ToOpcode == Instruction::Shl ||
+                          ToOpcode == Instruction::AShr;
+    if (Pos == 1 || ConstMustBeRHS)
+      return SmallVector<Value *>({NonConstOp, NewConstOp});
+    return SmallVector<Value *>({NewConstOp, NonConstOp});
+  }
+};
+
+/// Build the list of interchange strategies applicable to \p MainOp: the
+/// generic identical-opcode strategy always, plus the binary-operator
+/// strategy when \p MainOp is a binary operator.
+static SmallVector<std::unique_ptr<InterchangeableInstruction>>
+getInterchangeableInstruction(Instruction *MainOp) {
+  SmallVector<std::unique_ptr<InterchangeableInstruction>> Strategies;
+  Strategies.emplace_back(std::make_unique<InterchangeableInstruction>(MainOp));
+  if (!MainOp->isBinaryOp())
+    return Strategies;
+  Strategies.emplace_back(std::make_unique<InterchangeableBinOp>(MainOp));
+  return Strategies;
+}
+
+/// Keep only the strategies in \p Candidate that also accept \p I, in their
+/// original order. Returns false, erasing nothing, when none of them does.
+static bool getInterchangeableInstruction(
+    SmallVector<std::unique_ptr<InterchangeableInstruction>> &Candidate,
+    Instruction *I) {
+  auto AcceptsI = [I](const std::unique_ptr<InterchangeableInstruction> &S) {
+    return S->isSame(I);
+  };
+  auto FirstRejected =
+      std::stable_partition(Candidate.begin(), Candidate.end(), AcceptsI);
+  bool AnyAccepted = FirstRejected != Candidate.begin();
+  if (AnyAccepted)
+    Candidate.erase(FirstRejected, Candidate.end());
+  return AnyAccepted;
+}
+
+/// Return true if \p I either shares an opcode with \p MainOp or \p AltOp, or
+/// (for binary operators) can be rewritten into the form of one of them.
+static bool isConvertible(Instruction *I, Instruction *MainOp,
+                          Instruction *AltOp) {
+  assert(MainOp && "MainOp cannot be nullptr.");
+  unsigned Opcode = I->getOpcode();
+  if (Opcode == MainOp->getOpcode())
+    return true;
+  assert(AltOp && "AltOp cannot be nullptr.");
+  if (Opcode == AltOp->getOpcode())
+    return true;
+  if (!I->isBinaryOp())
+    return false;
+  // isSame narrows a strategy's internal state, so every target gets a
+  // freshly built set of strategies.
+  auto CanBecome = [I](Instruction *Target) {
+    for (std::unique_ptr<InterchangeableInstruction> &S :
+         getInterchangeableInstruction(I))
+      if (S->isSame(I) && S->isSame(Target))
+        return true;
+    return false;
+  };
+  return CanBecome(MainOp) || CanBecome(AltOp);
+}
+
+static std::pair<Instruction *, SmallVector<Value *>>
+convertTo(Instruction *I, Instruction *MainOp, Instruction *AltOp) {
+  assert(isConvertible(I, MainOp, AltOp) && "Cannot convert the instruction.");
+  if (I->getOpcode() == MainOp->getOpcode())
+    return std::make_pair(MainOp, SmallVector<Value *>(I->operands()));
+  // Prefer AltOp instead of interchangeable instruction of MainOp.
+  if (I->getOpcode() == AltOp->getOpcode())
+    return std::make_pair(AltOp, SmallVector<Value *>(I->operands()));
+  assert(I->isBinaryOp() && "Cannot convert the instruction.");
+  SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate(
+      getInterchangeableInstruction(I));
+  for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
+    if (C->isSame(I) && C->isSame(MainOp))
+      return std::make_pair(MainOp, C->getOperand(MainOp));
+  Candidate = getInterchangeableInstruction(I);
+  for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
+    if (C->isSame(I) && C->isSame(AltOp))
+      return std::make_pair(AltOp, C->getOperand(AltOp));
----------------
alexey-bataev wrote:

I still do not understand why you can't do the selection immediately in getInterchangeableInstruction, instead of needing an extra filtering step. Do the filtering during the lookup to reduce the number of candidates.

https://github.com/llvm/llvm-project/pull/127450


More information about the llvm-commits mailing list