[llvm] [SLP] Make getSameOpcode support interchangeable instructions. (PR #127450)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 24 07:27:57 PST 2025
================
@@ -810,6 +810,285 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
namespace {
+/// Base class for representing instructions that can be interchanged with other
+/// equivalent forms. For example, multiplication by a power of 2 can be
+/// interchanged with a left shift.
+///
+/// Derived classes implement specific interchange patterns by overriding the
+/// virtual methods to define their interchange logic.
+///
+/// The class keeps a pointer to the main instruction (MainOp) and provides
+/// methods to:
+/// - Check if another instruction is interchangeable (isSame)
+/// - Get the opcode for the interchangeable form (getOpcode)
+/// - Get the operands for the interchangeable form (getOperand)
+///
+/// The base implementation performs no conversion. It only treats
+/// instructions whose opcode equals MainOp's opcode as interchangeable.
+class InterchangeableInstruction {
+protected:
+  Instruction *const MainOp = nullptr;
+
+public:
+  InterchangeableInstruction(Instruction *MainOp) : MainOp(MainOp) {}
+
+  /// Returns true if \p I has the same opcode as MainOp.
+  virtual bool isSame(Instruction *I) const {
+    return MainOp->getOpcode() == I->getOpcode();
+  }
+
+  /// Returns the opcode every interchangeable instruction is expressed with.
+  /// For the base class this is simply MainOp's opcode.
+  virtual unsigned getOpcode() const { return MainOp->getOpcode(); }
+
+  /// Returns the operands of \p I expressed in the interchangeable form.
+  /// \p I must already have MainOp's opcode, so its own operands are returned
+  /// unchanged.
+  virtual SmallVector<Value *> getOperand(Instruction *I) const {
+    assert(MainOp->getOpcode() == I->getOpcode() &&
+           "Cannot convert the instruction.");
+    // Each lane must keep its own operands; returning MainOp's operands here
+    // would silently replace every lane's inputs with the main lane's.
+    return SmallVector<Value *>(I->operands());
+  }
+
+  virtual ~InterchangeableInstruction() = default;
+};
+
+class InterchangeableBinOp final : public InterchangeableInstruction {
+ // Bitmask type with one bit per opcode in SupportedOp (eight bits needed).
+ using MaskType = std::uint_fast8_t;
+ // Opcodes handled by this class. isSame() looks opcodes up with
+ // binary_search, so this list must stay sorted by opcode value.
+ constexpr static std::initializer_list<unsigned> SupportedOp = {
+ Instruction::Add, Instruction::Sub, Instruction::Mul, Instruction::Shl,
+ Instruction::AShr, Instruction::And, Instruction::Or, Instruction::Xor};
+ // One bit per supported opcode; opcodeToMask() maps an opcode to its bit.
+ enum : MaskType {
+ SHL_BIT = 0b1,
+ AShr_BIT = 0b10,
+ Mul_BIT = 0b100,
+ Add_BIT = 0b1000,
+ Sub_BIT = 0b10000,
+ And_BIT = 0b100000,
+ Or_BIT = 0b1000000,
+ Xor_BIT = 0b10000000,
+ };
+ // Set of opcodes that every instruction accepted so far by isSame() can be
+ // converted to. It starts with every supported opcode allowed and is
+ // narrowed by tryAnd(). It is mutable so the const isSame() can update it.
+ mutable MaskType Mask = Xor_BIT | Or_BIT | And_BIT | Sub_BIT | Add_BIT |
+ Mul_BIT | AShr_BIT | SHL_BIT;
+ // We cannot create an interchangeable instruction that does not exist in VL.
+ // For example, VL [x + 0, y * 1] can be converted to [x << 0, y << 0], but
+ // 'shl' does not exist in VL. In the end, we convert VL to [x * 1, y * 1].
+ // SeenBefore is used to know what operations have been seen before.
+ // isSame() sets the bit of every supported opcode it is given, before
+ // deciding whether the instruction is accepted.
+ mutable MaskType SeenBefore = 0;
+
+  /// Looks for an integer-constant operand of the binary operator \p I.
+  ///
+  /// The second operand is inspected first. The first operand is inspected
+  /// only when the second is not a constant and the opcode is not Sub, Shl
+  /// or AShr. A ConstantDataVector splat counts as its scalar ConstantInt.
+  /// \returns the ConstantInt (nullptr if none was found) paired with the
+  /// index of the operand that was inspected last (1 unless the first
+  /// operand matched as a constant).
+  static std::pair<ConstantInt *, unsigned>
+  isBinOpWithConstantInt(Instruction *I) {
+    // Reduce a matched constant to a scalar ConstantInt, looking through
+    // splat data vectors.
+    auto AsConstantInt = [](Constant *Matched) -> ConstantInt * {
+      if (auto *Scalar = dyn_cast<ConstantInt>(Matched))
+        return Scalar;
+      if (auto *Vec = dyn_cast<ConstantDataVector>(Matched))
+        return dyn_cast_if_present<ConstantInt>(Vec->getSplatValue());
+      return nullptr;
+    };
+
+    Constant *Found = nullptr;
+    if (match(I, m_BinOp(m_Value(), m_Constant(Found))))
+      return {AsConstantInt(Found), 1u};
+
+    switch (I->getOpcode()) {
+    case Instruction::Sub:
+    case Instruction::Shl:
+    case Instruction::AShr:
+      // Non-commutative: a constant on the left is never looked at.
+      return {nullptr, 1u};
+    default:
+      break;
+    }
+
+    if (!match(I, m_BinOp(m_Constant(Found), m_Value())))
+      return {nullptr, 1u};
+    return {AsConstantInt(Found), 0u};
+  }
+
+  /// Maps a supported binary opcode to its bit in Mask / SeenBefore.
+  /// Passing any opcode outside SupportedOp is a programming error.
+  static MaskType opcodeToMask(unsigned Opcode) {
+    MaskType Bit = 0;
+    switch (Opcode) {
+    case Instruction::Add: Bit = Add_BIT; break;
+    case Instruction::Sub: Bit = Sub_BIT; break;
+    case Instruction::Mul: Bit = Mul_BIT; break;
+    case Instruction::Shl: Bit = SHL_BIT; break;
+    case Instruction::AShr: Bit = AShr_BIT; break;
+    case Instruction::And: Bit = And_BIT; break;
+    case Instruction::Or: Bit = Or_BIT; break;
+    case Instruction::Xor: Bit = Xor_BIT; break;
+    default:
+      llvm_unreachable("Unsupported opcode.");
+    }
+    return Bit;
+  }
+
+  /// Restricts Mask to the opcodes in \p X, provided at least one of them is
+  /// still allowed. \returns false, leaving Mask untouched, otherwise.
+  bool tryAnd(MaskType X) const {
+    MaskType Narrowed = Mask & X;
+    if (!Narrowed)
+      return false;
+    Mask = Narrowed;
+    return true;
+  }
+
+public:
+ using InterchangeableInstruction::InterchangeableInstruction;
+  /// Checks whether \p I can be expressed with an opcode still allowed by
+  /// Mask.
+  ///
+  /// Unsupported opcodes are rejected. Otherwise \p I's opcode is recorded in
+  /// SeenBefore. Identity operations (x << 0, x * 1, x & -1, and x op 0 for
+  /// Add/Sub/AShr/Or/Xor) are accepted without touching Mask. Any other
+  /// instruction narrows Mask to the opcodes it can be rewritten as.
+  /// \returns false if none of those opcodes is still allowed.
+  bool isSame(Instruction *I) const override {
+    unsigned Opcode = I->getOpcode();
+    if (!binary_search(SupportedOp, Opcode))
+      return false;
+    SeenBefore |= opcodeToMask(Opcode);
+    ConstantInt *CI = isBinOpWithConstantInt(I).first;
+    if (CI) {
+      const APInt &CIValue = CI->getValue();
+      switch (Opcode) {
+      case Instruction::Shl:
+        if (CIValue.isZero())
+          return true;
+        // x << C equals x * (1 << C) only when C is smaller than the bit
+        // width. A larger shift amount yields poison and has no power-of-2
+        // multiplier representable in the type, so it may only stay a Shl.
+        if (CIValue.ult(CIValue.getBitWidth()))
+          return tryAnd(Mul_BIT | SHL_BIT);
+        break;
+      case Instruction::Mul:
+        if (CIValue.isOne())
+          return true;
+        // Multiplying by 2^k is the same as shifting left by k.
+        if (CIValue.isPowerOf2())
+          return tryAnd(Mul_BIT | SHL_BIT);
+        break;
+      case Instruction::And:
+        if (CIValue.isAllOnes())
+          return true;
+        break;
+      default:
+        // Add, Sub, AShr, Or, Xor: a zero constant makes them an identity.
+        if (CIValue.isZero())
+          return true;
+        break;
+      }
+    }
+    return tryAnd(opcodeToMask(Opcode));
+  }
+ unsigned getOpcode() const override {
+ MaskType Candidate = Mask & SeenBefore;
+ if (Candidate & SHL_BIT)
+ return Instruction::Shl;
----------------
alexey-bataev wrote:
If the mask contains shl, it is always the first candidate. If the mask contains ashr, it is always the second candidate. Why do you need to keep all candidates (in the mask) if you have a strict order that decides which one should be chosen?
https://github.com/llvm/llvm-project/pull/127450
More information about the llvm-commits
mailing list