[llvm] [SLP] Make getSameOpcode support interchangeable instructions. (PR #127450)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 20 11:37:54 PST 2025
================
@@ -810,6 +810,267 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
namespace {
+/// Base class for representing instructions that can be interchanged with other
+/// equivalent forms. For example, multiplication by a power of 2 can be
+/// interchanged with a left shift.
+///
+/// Derived classes implement specific interchange patterns by overriding the
+/// virtual methods to define their interchange logic.
+///
+/// The class maintains a reference to the main instruction (MainOp) and
+/// provides methods to:
+/// - Check if another instruction is interchangeable (isSame)
+/// - Get the opcode for the interchangeable form (getOpcode)
+/// - Get the operands for the interchangeable form (getOperand)
+class InterchangeableInstruction {
+protected:
+ Instruction *const MainOp;
+
+public:
+ InterchangeableInstruction(Instruction *MainOp) : MainOp(MainOp) {}
+ virtual bool isSame(Instruction *I) {
+ return MainOp->getOpcode() == I->getOpcode();
+ }
+ virtual unsigned getOpcode() { return MainOp->getOpcode(); }
+ virtual SmallVector<Value *> getOperand(Instruction *I) {
+ assert(MainOp->getOpcode() == I->getOpcode());
+ return SmallVector<Value *>(MainOp->operands());
+ }
+ virtual ~InterchangeableInstruction() = default;
+};
+
+class InterchangeableBinOp final : public InterchangeableInstruction {
+ using MaskType = std::uint_fast8_t;
+ constexpr static std::initializer_list<unsigned> SupportedOp = {
+ Instruction::Add, Instruction::Sub, Instruction::Mul, Instruction::Shl,
+ Instruction::AShr, Instruction::And, Instruction::Or, Instruction::Xor};
+ // from high to low bit: Xor Or And Sub Add Mul AShr Shl
+ MaskType Mask = 0b11111111;
+ MaskType SeenBefore = 0;
+
+ /// Return a non-nullptr if either operand of I is a ConstantInt.
+ static std::pair<ConstantInt *, unsigned>
+ isBinOpWithConstantInt(Instruction *I) {
+ unsigned Opcode = I->getOpcode();
+ unsigned Pos = 1;
+ Constant *C;
+ if (!match(I, m_BinOp(m_Value(), m_Constant(C)))) {
+ if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
+ Opcode == Instruction::AShr)
+ return std::make_pair(nullptr, Pos);
+ if (!match(I, m_BinOp(m_Constant(C), m_Value())))
+ return std::make_pair(nullptr, Pos);
+ Pos = 0;
+ }
+ if (auto *CI = dyn_cast<ConstantInt>(C))
+ return std::make_pair(CI, Pos);
+ if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
+ if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
+ return std::make_pair(CI, Pos);
+ }
+ return std::make_pair(nullptr, Pos);
+ }
+
+ static MaskType opcodeToMask(unsigned Opcode) {
+ switch (Opcode) {
+ case Instruction::Shl:
+ return 0b1;
+ case Instruction::AShr:
+ return 0b10;
+ case Instruction::Mul:
+ return 0b100;
+ case Instruction::Add:
+ return 0b1000;
+ case Instruction::Sub:
+ return 0b10000;
+ case Instruction::And:
+ return 0b100000;
+ case Instruction::Or:
+ return 0b1000000;
+ case Instruction::Xor:
+ return 0b10000000;
+ }
+ llvm_unreachable("Unsupported opcode.");
+ }
+
+ bool tryAnd(MaskType X) {
+ if (Mask & X) {
+ Mask &= X;
+ return true;
+ }
+ return false;
+ }
+
+public:
+ using InterchangeableInstruction::InterchangeableInstruction;
+ bool isSame(Instruction *I) override {
+ unsigned Opcode = I->getOpcode();
+ if (!binary_search(SupportedOp, Opcode))
+ return false;
+ SeenBefore |= opcodeToMask(Opcode);
+ ConstantInt *CI = isBinOpWithConstantInt(I).first;
+ if (CI) {
+ const APInt &CIValue = CI->getValue();
+ switch (Opcode) {
+ case Instruction::Shl:
+ if (CIValue.isZero())
+ return true;
+ return tryAnd(0b101);
+ case Instruction::Mul:
+ if (CIValue.isOne())
+ return true;
+ if (CIValue.isPowerOf2())
+ return tryAnd(0b101);
+ break;
+ case Instruction::And:
+ if (CIValue.isAllOnes())
+ return true;
+ break;
+ default:
+ if (CIValue.isZero())
+ return true;
+ break;
+ }
+ }
+ return tryAnd(opcodeToMask(Opcode));
+ }
+ unsigned getOpcode() override {
+ MaskType Candidate = Mask & SeenBefore;
+ if (Candidate & 0b1)
+ return Instruction::Shl;
+ if (Candidate & 0b10)
+ return Instruction::AShr;
+ if (Candidate & 0b100)
+ return Instruction::Mul;
+ if (Candidate & 0b1000)
+ return Instruction::Add;
+ if (Candidate & 0b10000)
+ return Instruction::Sub;
+ if (Candidate & 0b100000)
+ return Instruction::And;
+ if (Candidate & 0b1000000)
+ return Instruction::Or;
+ if (Candidate & 0b10000000)
+ return Instruction::Xor;
+ llvm_unreachable("Cannot find interchangeable instruction.");
+ }
+ SmallVector<Value *> getOperand(Instruction *I) override {
+ unsigned ToOpcode = I->getOpcode();
+ assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
+ unsigned FromOpcode = MainOp->getOpcode();
+ if (FromOpcode == ToOpcode)
+ return SmallVector<Value *>(MainOp->operands());
+ auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
+ const APInt &FromCIValue = CI->getValue();
+ unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
+ APInt ToCIValue;
+ switch (FromOpcode) {
+ case Instruction::Shl:
+ if (ToOpcode == Instruction::Mul) {
+ ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
+ FromCIValue.getZExtValue());
+ } else {
+ assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+ ToCIValue = ToOpcode == Instruction::And
+ ? APInt::getAllOnes(FromCIValueBitWidth)
+ : APInt::getZero(FromCIValueBitWidth);
+ }
+ break;
+ case Instruction::Mul:
+ assert(FromCIValue.isPowerOf2() && "Cannot convert the instruction.");
+ if (ToOpcode == Instruction::Shl) {
+ ToCIValue = APInt(FromCIValueBitWidth, FromCIValue.logBase2());
+ } else {
+ assert(FromCIValue.isOne() && "Cannot convert the instruction.");
+ ToCIValue = ToOpcode == Instruction::And
+ ? APInt::getAllOnes(FromCIValueBitWidth)
+ : APInt::getZero(FromCIValueBitWidth);
+ }
+ break;
+ case Instruction::And:
+ assert(FromCIValue.isAllOnes() && "Cannot convert the instruction.");
+ ToCIValue = ToOpcode == Instruction::Mul
+ ? APInt::getOneBitSet(FromCIValueBitWidth, 0)
+ : APInt::getZero(FromCIValueBitWidth);
+ break;
+ default:
+ ToCIValue = APInt::getZero(FromCIValueBitWidth);
+ break;
+ }
+ auto LHS = MainOp->getOperand(1 - Pos);
+ auto RHS = ConstantInt::get(MainOp->getOperand(Pos)->getType(), ToCIValue);
+ if (Pos == 1)
+ return SmallVector<Value *>({LHS, RHS});
+ return SmallVector<Value *>({RHS, LHS});
+ }
+};
+
+static SmallVector<std::unique_ptr<InterchangeableInstruction>>
+getInterchangeableInstruction(Instruction *MainOp) {
+ SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate;
+ Candidate.push_back(std::make_unique<InterchangeableInstruction>(MainOp));
+ if (MainOp->isBinaryOp())
+ Candidate.push_back(std::make_unique<InterchangeableBinOp>(MainOp));
+ return Candidate;
+}
+
+static bool getInterchangeableInstruction(
+ SmallVector<std::unique_ptr<InterchangeableInstruction>> &Candidate,
+ Instruction *I) {
+ auto Iter = std::stable_partition(
+ Candidate.begin(), Candidate.end(),
+ [&](const std::unique_ptr<InterchangeableInstruction> &C) {
+ return C->isSame(I);
+ });
+ if (Iter == Candidate.begin())
+ return false;
+ Candidate.erase(Iter, Candidate.end());
+ return true;
+}
+
+static bool isConvertible(Instruction *I, Instruction *MainOp,
+ Instruction *AltOp) {
+ assert(MainOp && "MainOp cannot be nullptr.");
+ if (I->getOpcode() == MainOp->getOpcode())
+ return true;
+ assert(AltOp && "AltOp cannot be nullptr.");
+ if (I->getOpcode() == AltOp->getOpcode())
+ return true;
+ if (!I->isBinaryOp())
+ return false;
+ SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate(
+ getInterchangeableInstruction(I));
+ for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
+ if (C->isSame(I) && C->isSame(MainOp))
+ return true;
+ Candidate = getInterchangeableInstruction(I);
+ for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
+ if (C->isSame(I) && C->isSame(AltOp))
+ return true;
+ return false;
+}
+
+static std::pair<Instruction *, SmallVector<Value *>>
+convertTo(Instruction *I, Instruction *MainOp, Instruction *AltOp) {
+ assert(isConvertible(I, MainOp, AltOp) && "Cannot convert the instruction.");
+ if (I->getOpcode() == MainOp->getOpcode())
+ return std::make_pair(MainOp, SmallVector<Value *>(I->operands()));
+ // Prefer AltOp instead of interchangeable instruction of MainOp.
+ if (I->getOpcode() == AltOp->getOpcode())
+ return std::make_pair(AltOp, SmallVector<Value *>(I->operands()));
+ assert(I->isBinaryOp() && "Cannot convert the instruction.");
+ SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate(
+ getInterchangeableInstruction(I));
+ for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
+ if (C->isSame(I) && C->isSame(MainOp))
+ return std::make_pair(MainOp, C->getOperand(MainOp));
+ Candidate = getInterchangeableInstruction(I);
+ for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
+ if (C->isSame(I) && C->isSame(AltOp))
+ return std::make_pair(AltOp, C->getOperand(AltOp));
----------------
alexey-bataev wrote:
Why do you need this lookup instead of returning it as a result of `getInterchangeableInstruction`?
https://github.com/llvm/llvm-project/pull/127450
More information about the llvm-commits
mailing list