[llvm] [SLP] Make getSameOpcode support interchangeable instructions. (PR #127450)
Han-Kuan Chen via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 5 10:40:36 PST 2025
================
@@ -810,11 +810,249 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
namespace {
+/// Base class for representing instructions that can be interchanged with other
+/// equivalent forms. For example, multiplication by a power of 2 can be
+/// interchanged with a left shift.
+///
+/// The class maintains a reference to the main instruction (MainOp) and
+/// provides methods to:
+/// - Check if the incoming instruction can use the same instruction as MainOp
+///   (add)
+/// - Check, without changing any state, whether an instruction is still
+///   compatible with the accumulated mask (contain)
+/// - Get the opcode for the interchangeable form (getOpcode)
+/// - Get the operands for the interchangeable form (getOperand)
+class InterchangeableBinOp {
+  using MaskType = std::uint_fast16_t;
+  // Sort SupportedOp because it is used by binary_search.
+  constexpr static std::initializer_list<unsigned> SupportedOp = {
+      Instruction::Add, Instruction::Sub, Instruction::Mul, Instruction::Shl,
+      Instruction::AShr, Instruction::And, Instruction::Or, Instruction::Xor};
+  // One bit per supported opcode plus one bit meaning "the opcode of MainOp
+  // itself" (used for opcodes that are not listed in SupportedOp).
+  enum : MaskType {
+    NOBIT = 0,
+    ShlBIT = 0b1,
+    AShrBIT = 0b10,
+    MulBIT = 0b100,
+    AddBIT = 0b1000,
+    SubBIT = 0b10000,
+    AndBIT = 0b100000,
+    OrBIT = 0b1000000,
+    XorBIT = 0b10000000,
+    MainOpBIT = 0b100000000,
+    LLVM_MARK_AS_BITMASK_ENUM(MainOpBIT)
+  };
+  Instruction *MainOp = nullptr;
+  // Each set bit names an opcode that every instruction passed to add() so far
+  // can be expressed with. It starts with all bits set and is only ever
+  // narrowed by trySet().
+  MaskType Mask = MainOpBIT | XorBIT | OrBIT | AndBIT | SubBIT | AddBIT |
+                  MulBIT | AShrBIT | ShlBIT;
+  // We cannot create an interchangeable instruction that does not exist in VL.
+  // For example, VL [x + 0, y * 1] can be converted to [x << 0, y << 0], but
+  // 'shl' does not exist in VL. In the end, we convert VL to [x * 1, y * 1].
+  // SeenBefore is used to know what operations have been seen before.
+  MaskType SeenBefore = 0;
+
+  // Return a non-nullptr if either operand of I is a ConstantInt (or a
+  // ConstantDataVector splat of one).
+  // The second return value represents the operand position. We check the
+  // right-hand side first (1). If the right hand side is not a constant and
+  // the instruction is neither Sub, Shl, nor AShr, we then check the left hand
+  // side (0).
+  static std::pair<ConstantInt *, unsigned>
+  isBinOpWithConstantInt(Instruction *I) {
+    unsigned Opcode = I->getOpcode();
+    assert(binary_search(SupportedOp, Opcode) && "Unsupported opcode.");
+    unsigned Pos = 1;
+    Constant *C;
+    if (!match(I, m_BinOp(m_Value(), m_Constant(C)))) {
+      // A constant on the left is only meaningful for commutative opcodes.
+      if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
+          Opcode == Instruction::AShr)
+        return std::make_pair(nullptr, Pos);
+      if (!match(I, m_BinOp(m_Constant(C), m_Value())))
+        return std::make_pair(nullptr, Pos);
+      Pos = 0;
+    }
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      return std::make_pair(CI, Pos);
+    if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
+      if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
+        return std::make_pair(CI, Pos);
+    }
+    return std::make_pair(nullptr, Pos);
+  }
+
+  // Return the constant C such that `x <Opcode> C` equals x: 1 for Mul,
+  // all-ones for And and 0 for every other supported opcode.
+  static APInt getIdentityConstant(unsigned Opcode, unsigned BitWidth) {
+    switch (Opcode) {
+    case Instruction::Mul:
+      return APInt::getOneBitSet(BitWidth, 0);
+    case Instruction::And:
+      return APInt::getAllOnes(BitWidth);
+    default:
+      return APInt::getZero(BitWidth);
+    }
+  }
+
+  // Map an opcode to its mask bit. Opcodes outside SupportedOp map to
+  // MainOpBIT when they equal the opcode of MainOp and to NOBIT otherwise.
+  MaskType opcodeToMask(unsigned Opcode) const {
+    switch (Opcode) {
+    case Instruction::Shl:
+      return ShlBIT;
+    case Instruction::AShr:
+      return AShrBIT;
+    case Instruction::Mul:
+      return MulBIT;
+    case Instruction::Add:
+      return AddBIT;
+    case Instruction::Sub:
+      return SubBIT;
+    case Instruction::And:
+      return AndBIT;
+    case Instruction::Or:
+      return OrBIT;
+    case Instruction::Xor:
+      return XorBIT;
+    }
+    return Opcode == MainOp->getOpcode() ? MainOpBIT : NOBIT;
+  }
+
+  // Return the set of opcodes I can be rewritten to. Without a usable constant
+  // operand this is just the bit of I's own opcode.
+  MaskType getInterchangeableMask(Instruction *I) const {
+    unsigned Opcode = I->getOpcode();
+    if (!binary_search(SupportedOp, Opcode))
+      return opcodeToMask(Opcode);
+    ConstantInt *CI = isBinOpWithConstantInt(I).first;
+    if (CI) {
+      constexpr MaskType CanBeAll =
+          XorBIT | OrBIT | AndBIT | SubBIT | AddBIT | MulBIT | AShrBIT | ShlBIT;
+      const APInt &CIValue = CI->getValue();
+      switch (Opcode) {
+      case Instruction::Shl:
+        // A shift amount that is not smaller than the bit width has no
+        // equivalent multiplier (getOperand would ask APInt for a bit index
+        // that is out of range), so keep such a shift as a plain Shl.
+        if (CIValue.uge(CIValue.getBitWidth()))
+          break;
+        if (CIValue.isZero())
+          return CanBeAll;
+        return MulBIT | ShlBIT;
+      case Instruction::Mul:
+        if (CIValue.isOne())
+          return CanBeAll;
+        if (CIValue.isPowerOf2())
+          return MulBIT | ShlBIT;
+        break;
+      case Instruction::Add:
+      case Instruction::Sub:
+        if (CIValue.isZero())
+          return CanBeAll;
+        return SubBIT | AddBIT;
+      case Instruction::And:
+        if (CIValue.isAllOnes())
+          return CanBeAll;
+        break;
+      default:
+        // AShr, Or and Xor are only interchangeable when the constant is 0.
+        if (CIValue.isZero())
+          return CanBeAll;
+        break;
+      }
+    }
+    return opcodeToMask(Opcode);
+  }
+
+  // Return false allows getSameOpcode to find an alternate instruction.
+  // Directly setting the mask will destroy the mask state, preventing us from
+  // determining which instruction the MainOp should convert to. The mask is
+  // therefore only narrowed when the intersection is non-empty.
+  bool trySet(MaskType X) {
+    if (Mask & X) {
+      Mask &= X;
+      return true;
+    }
+    return false;
+  }
+
+public:
+  InterchangeableBinOp(Instruction *MainOp) : MainOp(MainOp) {
+    assert(is_sorted(SupportedOp) && "SupportedOp is not sorted.");
+  }
+  /// Record the opcode of \p I as seen and narrow the mask to the opcodes
+  /// \p I can be expressed with. Returns false, leaving the mask untouched,
+  /// when \p I shares no opcode with the instructions added so far.
+  bool add(Instruction *I) {
+    SeenBefore |= opcodeToMask(I->getOpcode());
+    return trySet(getInterchangeableMask(I));
+  }
+  /// Return true if \p I shares at least one opcode with the current mask.
+  /// Does not modify any state.
+  bool contain(Instruction *I) const {
+    return Mask & getInterchangeableMask(I);
+  }
+  /// Return the opcode to use, chosen among the opcodes that are both still
+  /// allowed by the mask and were actually seen by add(). The opcode of MainOp
+  /// is checked first, then Shl, AShr, Mul, Add, Sub, And, Or and Xor.
+  unsigned getOpcode() const {
+    MaskType Candidate = Mask & SeenBefore;
+    if (Candidate & MainOpBIT)
+      return MainOp->getOpcode();
+    if (Candidate & ShlBIT)
+      return Instruction::Shl;
+    if (Candidate & AShrBIT)
+      return Instruction::AShr;
+    if (Candidate & MulBIT)
+      return Instruction::Mul;
+    if (Candidate & AddBIT)
+      return Instruction::Add;
+    if (Candidate & SubBIT)
+      return Instruction::Sub;
+    if (Candidate & AndBIT)
+      return Instruction::And;
+    if (Candidate & OrBIT)
+      return Instruction::Or;
+    if (Candidate & XorBIT)
+      return Instruction::Xor;
+    llvm_unreachable("Cannot find interchangeable instruction.");
+  }
+  /// Return the operands MainOp would have if it were rewritten with the
+  /// opcode of \p I. When both opcodes already match, the operands of MainOp
+  /// are returned unchanged.
+  SmallVector<Value *> getOperand(Instruction *I) const {
+    unsigned ToOpcode = I->getOpcode();
+    unsigned FromOpcode = MainOp->getOpcode();
+    if (FromOpcode == ToOpcode)
+      return SmallVector<Value *>(MainOp->operands());
+    assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
+    auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
+    assert(CI && "Cannot convert an instruction without a constant operand.");
+    const APInt &FromCIValue = CI->getValue();
+    unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
+    APInt ToCIValue;
+    switch (FromOpcode) {
+    case Instruction::Shl:
+      if (ToOpcode == Instruction::Mul) {
+        assert(FromCIValue.ult(FromCIValueBitWidth) &&
+               "Cannot convert the instruction.");
+        // x << C  ==>  x * (1 << C)
+        ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
+                                        FromCIValue.getZExtValue());
+      } else {
+        assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+        ToCIValue = getIdentityConstant(ToOpcode, FromCIValueBitWidth);
+      }
+      break;
+    case Instruction::Mul:
+      assert(FromCIValue.isPowerOf2() && "Cannot convert the instruction.");
+      if (ToOpcode == Instruction::Shl) {
+        // x * 2^N  ==>  x << N
+        ToCIValue = APInt(FromCIValueBitWidth, FromCIValue.logBase2());
+      } else {
+        assert(FromCIValue.isOne() && "Cannot convert the instruction.");
+        ToCIValue = getIdentityConstant(ToOpcode, FromCIValueBitWidth);
+      }
+      break;
+    case Instruction::Add:
+    case Instruction::Sub:
+      if (FromCIValue.isZero()) {
+        // x + 0 and x - 0 are plain x, so the new constant must be the
+        // identity of the target opcode (1 for Mul, all-ones for And), not 0.
+        ToCIValue = getIdentityConstant(ToOpcode, FromCIValueBitWidth);
+      } else {
+        assert(is_contained({Instruction::Add, Instruction::Sub}, ToOpcode) &&
+               "Cannot convert the instruction.");
+        // x + C  <==>  x - (-C)
+        ToCIValue = FromCIValue;
+        ToCIValue.negate();
+      }
+      break;
+    case Instruction::And:
+      assert(FromCIValue.isAllOnes() && "Cannot convert the instruction.");
+      ToCIValue = getIdentityConstant(ToOpcode, FromCIValueBitWidth);
+      break;
+    default:
+      // AShr, Or and Xor: only their identity form (constant 0) converts.
+      assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+      ToCIValue = getIdentityConstant(ToOpcode, FromCIValueBitWidth);
+      break;
+    }
+    Value *Var = MainOp->getOperand(1 - Pos);
+    Constant *NewC =
+        ConstantInt::get(MainOp->getOperand(Pos)->getType(), ToCIValue);
+    // Sub, Shl and AShr are not commutative, so the constant has to end up on
+    // the right-hand side even if MainOp had it on the left (C + x must become
+    // x - (-C), not (-C) - x). Otherwise keep the constant where it was.
+    const bool ToIsNonCommutative = ToOpcode == Instruction::Sub ||
+                                    ToOpcode == Instruction::Shl ||
+                                    ToOpcode == Instruction::AShr;
+    if (Pos == 1 || ToIsNonCommutative)
+      return SmallVector<Value *>({Var, NewC});
+    return SmallVector<Value *>({NewC, Var});
+  }
+};
+
/// Main data required for vectorization of instructions.
class InstructionsState {
/// The main/alternate instruction. MainOp is also VL0.
Instruction *MainOp = nullptr;
Instruction *AltOp = nullptr;
+ // Only BinaryOperator will activate this.
+ std::optional<InterchangeableBinOp> MainOpConverter;
+ std::optional<InterchangeableBinOp> AltOpConverter;
----------------
HanKuanChen wrote:
Revert ad7bec92d18a7bc908d6717bdb49b69843ab323c and ddcd456f749f1c4b12f7b96aa7c52197c0ac201a to remove the data members.
https://github.com/llvm/llvm-project/pull/127450
More information about the llvm-commits
mailing list