[llvm] [SLP] Make getSameOpcode support interchangeable instructions. (PR #127450)

Han-Kuan Chen via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 5 10:40:36 PST 2025


================
@@ -810,11 +810,249 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) {
 
 namespace {
 
+/// Tracks a binary operator that can be rewritten into an equivalent form
+/// that uses a different opcode. For example, multiplication by a power of 2
+/// can be interchanged with a left shift, and `x + 0` can become `x * 1`.
+///
+/// The class keeps a reference to one instruction (MainOp) and provides
+/// methods to:
+/// - Check whether an incoming instruction can share an opcode with MainOp,
+///   narrowing the set of candidate opcodes on success (add)
+/// - Check, without changing any state, whether an instruction is compatible
+///   with the current candidate set (contain)
+/// - Pick the opcode for the interchangeable form (getOpcode)
+/// - Rewrite MainOp's operands for a given target opcode (getOperand)
+class InterchangeableBinOp {
+  using MaskType = std::uint_fast16_t;
+  // Must stay sorted by opcode value because it is searched with
+  // binary_search.
+  constexpr static std::initializer_list<unsigned> SupportedOp = {
+      Instruction::Add,  Instruction::Sub, Instruction::Mul, Instruction::Shl,
+      Instruction::AShr, Instruction::And, Instruction::Or,  Instruction::Xor};
+  enum : MaskType {
+    NOBIT = 0,
+    ShlBIT = 0b1,
+    AShrBIT = 0b10,
+    MulBIT = 0b100,
+    AddBIT = 0b1000,
+    SubBIT = 0b10000,
+    AndBIT = 0b100000,
+    OrBIT = 0b1000000,
+    XorBIT = 0b10000000,
+    MainOpBIT = 0b100000000,
+    LLVM_MARK_AS_BITMASK_ENUM(MainOpBIT)
+  };
+  Instruction *MainOp = nullptr;
+  // Each set bit names an opcode that MainOp (and every instruction accepted
+  // by add() so far) can still be expressed with. add() only ever clears
+  // bits.
+  MaskType Mask = MainOpBIT | XorBIT | OrBIT | AndBIT | SubBIT | AddBIT |
+                  MulBIT | AShrBIT | ShlBIT;
+  // We cannot create an interchangeable instruction that does not exist in VL.
+  // For example, VL [x + 0, y * 1] can be converted to [x << 0, y << 0], but
+  // 'shl' does not exist in VL. In the end, we convert VL to [x * 1, y * 1].
+  // SeenBefore records the opcodes of the instructions accepted by add().
+  MaskType SeenBefore = 0;
+
+  // Return a non-nullptr if either operand of I is a ConstantInt (or a
+  // ConstantDataVector splat of one).
+  // The second return value represents the operand position. We check the
+  // right-hand side first (1). If the right-hand side is not a constant and
+  // the instruction is neither Sub, Shl, nor AShr, we then check the
+  // left-hand side (0).
+  static std::pair<ConstantInt *, unsigned>
+  isBinOpWithConstantInt(Instruction *I) {
+    unsigned Opcode = I->getOpcode();
+    assert(binary_search(SupportedOp, Opcode) && "Unsupported opcode.");
+    unsigned Pos = 1;
+    Constant *C;
+    if (!match(I, m_BinOp(m_Value(), m_Constant(C)))) {
+      if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
+          Opcode == Instruction::AShr)
+        return std::make_pair(nullptr, Pos);
+      if (!match(I, m_BinOp(m_Constant(C), m_Value())))
+        return std::make_pair(nullptr, Pos);
+      Pos = 0;
+    }
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      return std::make_pair(CI, Pos);
+    if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
+      if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
+        return std::make_pair(CI, Pos);
+    }
+    return std::make_pair(nullptr, Pos);
+  }
+
+  // Map an opcode to its mask bit. An unsupported opcode maps to MainOpBIT
+  // when it equals MainOp's opcode and to NOBIT otherwise. The order in which
+  // getOpcode() prefers these bits is: MainOp, Shl, AShr, Mul, Add, Sub, And,
+  // Or, Xor.
+  MaskType opcodeToMask(unsigned Opcode) const {
+    switch (Opcode) {
+    case Instruction::Shl:
+      return ShlBIT;
+    case Instruction::AShr:
+      return AShrBIT;
+    case Instruction::Mul:
+      return MulBIT;
+    case Instruction::Add:
+      return AddBIT;
+    case Instruction::Sub:
+      return SubBIT;
+    case Instruction::And:
+      return AndBIT;
+    case Instruction::Or:
+      return OrBIT;
+    case Instruction::Xor:
+      return XorBIT;
+    }
+    return Opcode == MainOp->getOpcode() ? MainOpBIT : NOBIT;
+  }
+
+  // Return the set of opcodes I can be rewritten as. An instruction whose
+  // constant operand makes it an identity operation (x << 0, x * 1, x +/- 0,
+  // x & -1, x | 0, x ^ 0, x >>s 0) can become any supported opcode. The
+  // returned mask always contains I's own opcode bit.
+  MaskType getInterchangeableMask(Instruction *I) const {
+    unsigned Opcode = I->getOpcode();
+    if (!binary_search(SupportedOp, Opcode))
+      return opcodeToMask(Opcode);
+    ConstantInt *CI = isBinOpWithConstantInt(I).first;
+    if (CI) {
+      constexpr MaskType CanBeAll =
+          XorBIT | OrBIT | AndBIT | SubBIT | AddBIT | MulBIT | AShrBIT | ShlBIT;
+      const APInt &CIValue = CI->getValue();
+      switch (Opcode) {
+      case Instruction::Shl:
+        // A shift amount >= the bit width yields poison and has no Mul
+        // equivalent (1 << amount does not fit in the type), so such a shift
+        // can only stay a Shl.
+        if (CIValue.uge(CIValue.getBitWidth()))
+          break;
+        if (CIValue.isZero())
+          return CanBeAll;
+        return MulBIT | ShlBIT;
+      case Instruction::Mul:
+        if (CIValue.isOne())
+          return CanBeAll;
+        if (CIValue.isPowerOf2())
+          return MulBIT | ShlBIT;
+        break;
+      case Instruction::Add:
+      case Instruction::Sub:
+        if (CIValue.isZero())
+          return CanBeAll;
+        return SubBIT | AddBIT;
+      case Instruction::And:
+        if (CIValue.isAllOnes())
+          return CanBeAll;
+        break;
+      default:
+        // AShr, Or and Xor.
+        if (CIValue.isZero())
+          return CanBeAll;
+        break;
+      }
+    }
+    return opcodeToMask(Opcode);
+  }
+
+  // Narrow Mask to X if the two overlap. Returning false (and leaving Mask
+  // untouched) lets getSameOpcode look for an alternate instruction instead;
+  // assigning X unconditionally would destroy the mask state and lose track
+  // of which opcode MainOp should be converted to.
+  bool trySet(MaskType X) {
+    if (Mask & X) {
+      Mask &= X;
+      return true;
+    }
+    return false;
+  }
+
+public:
+  InterchangeableBinOp(Instruction *MainOp) : MainOp(MainOp) {
+    assert(is_sorted(SupportedOp) && "SupportedOp is not sorted.");
+  }
+  /// Try to accept I into this group. On success the candidate opcode set is
+  /// narrowed and I's opcode is recorded as seen; on failure nothing changes.
+  bool add(Instruction *I) {
+    if (!trySet(getInterchangeableMask(I)))
+      return false;
+    SeenBefore |= opcodeToMask(I->getOpcode());
+    return true;
+  }
+  /// Return true if I is compatible with the current candidate opcode set.
+  /// Does not modify any state.
+  bool contain(Instruction *I) const {
+    return Mask & getInterchangeableMask(I);
+  }
+  /// Return the preferred opcode that is both still allowed and was actually
+  /// seen among the accepted instructions.
+  unsigned getOpcode() const {
+    MaskType Candidate = Mask & SeenBefore;
+    if (Candidate & MainOpBIT)
+      return MainOp->getOpcode();
+    if (Candidate & ShlBIT)
+      return Instruction::Shl;
+    if (Candidate & AShrBIT)
+      return Instruction::AShr;
+    if (Candidate & MulBIT)
+      return Instruction::Mul;
+    if (Candidate & AddBIT)
+      return Instruction::Add;
+    if (Candidate & SubBIT)
+      return Instruction::Sub;
+    if (Candidate & AndBIT)
+      return Instruction::And;
+    if (Candidate & OrBIT)
+      return Instruction::Or;
+    if (Candidate & XorBIT)
+      return Instruction::Xor;
+    llvm_unreachable("Cannot find interchangeable instruction.");
+  }
+  /// Return the operands MainOp would have if it were rewritten with I's
+  /// opcode. The caller must only request an opcode the mask allows for
+  /// MainOp; otherwise the assertions below fire.
+  SmallVector<Value *> getOperand(Instruction *I) const {
+    unsigned ToOpcode = I->getOpcode();
+    unsigned FromOpcode = MainOp->getOpcode();
+    if (FromOpcode == ToOpcode)
+      return SmallVector<Value *>(MainOp->operands());
+    assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
+    auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
+    const APInt &FromCIValue = CI->getValue();
+    unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
+    // The constant C for which `x ToOpcode C` == x. Identity conversions must
+    // use it: e.g. `x + 0` becomes `x * 1` and `x & -1`, not `x * 0`/`x & 0`.
+    APInt IdentityValue;
+    switch (ToOpcode) {
+    case Instruction::Mul:
+      IdentityValue = APInt::getOneBitSet(FromCIValueBitWidth, 0);
+      break;
+    case Instruction::And:
+      IdentityValue = APInt::getAllOnes(FromCIValueBitWidth);
+      break;
+    default:
+      IdentityValue = APInt::getZero(FromCIValueBitWidth);
+      break;
+    }
+    APInt ToCIValue;
+    switch (FromOpcode) {
+    case Instruction::Shl:
+      if (ToOpcode == Instruction::Mul) {
+        // x << C == x * (1 << C); C < bit width is guaranteed by the mask.
+        ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
+                                        FromCIValue.getZExtValue());
+      } else {
+        assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+        ToCIValue = IdentityValue;
+      }
+      break;
+    case Instruction::Mul:
+      assert(FromCIValue.isPowerOf2() && "Cannot convert the instruction.");
+      if (ToOpcode == Instruction::Shl) {
+        ToCIValue = APInt(FromCIValueBitWidth, FromCIValue.logBase2());
+      } else {
+        assert(FromCIValue.isOne() && "Cannot convert the instruction.");
+        ToCIValue = IdentityValue;
+      }
+      break;
+    case Instruction::Add:
+    case Instruction::Sub:
+      if (FromCIValue.isZero()) {
+        ToCIValue = IdentityValue;
+      } else {
+        assert(is_contained({Instruction::Add, Instruction::Sub}, ToOpcode) &&
+               "Cannot convert the instruction.");
+        // x + C == x - (-C) and x - C == x + (-C).
+        ToCIValue = FromCIValue;
+        ToCIValue.negate();
+      }
+      break;
+    case Instruction::And:
+      assert(FromCIValue.isAllOnes() && "Cannot convert the instruction.");
+      ToCIValue = IdentityValue;
+      break;
+    default:
+      // AShr, Or and Xor are only interchangeable with a zero constant.
+      assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+      ToCIValue = IdentityValue;
+      break;
+    }
+    Value *LHS = MainOp->getOperand(1 - Pos);
+    Constant *RHS =
+        ConstantInt::get(MainOp->getOperand(Pos)->getType(), ToCIValue);
+    // A constant that was on the left may only stay there when the new
+    // opcode is commutative: `C + x` must become `x - (-C)`, not `(-C) - x`,
+    // and `4 * x` must become `x << 2`, not `2 << x`.
+    bool KeepConstantOnRHS =
+        Pos == 1 ||
+        is_contained({Instruction::Sub, Instruction::Shl, Instruction::AShr},
+                     ToOpcode);
+    if (KeepConstantOnRHS)
+      return SmallVector<Value *>({LHS, RHS});
+    return SmallVector<Value *>({RHS, LHS});
+  }
+};
+
 /// Main data required for vectorization of instructions.
 class InstructionsState {
   /// The main/alternate instruction. MainOp is also VL0.
   Instruction *MainOp = nullptr;
   Instruction *AltOp = nullptr;
+  // Only BinaryOperator will activate this.
+  std::optional<InterchangeableBinOp> MainOpConverter;
+  std::optional<InterchangeableBinOp> AltOpConverter;
----------------
HanKuanChen wrote:

Revert ad7bec92d18a7bc908d6717bdb49b69843ab323c and ddcd456f749f1c4b12f7b96aa7c52197c0ac201a to remove the data members.

https://github.com/llvm/llvm-project/pull/127450


More information about the llvm-commits mailing list