[llvm] [SLP] Make getSameOpcode support interchangeable instructions. (PR #135797)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 15 08:52:21 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: Han-Kuan Chen (HanKuanChen)

<details>
<summary>Changes</summary>

We use the term "interchangeable instructions" to refer to instructions with
different opcodes that compute the same result (e.g., `add x, 0` is
equivalent to `mul x, 1`).
Only instructions whose other operand is a constant integer are considered;
supporting non-constant operands may incur high costs for little benefit.

---

Patch is 72.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135797.diff


26 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+411-51) 
- (modified) llvm/test/Transforms/SLPVectorizer/AArch64/vec3-base.ll (+4-4) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll (+3-4) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/vec3-base.ll (+4-4) 
- (added) llvm/test/Transforms/SLPVectorizer/X86/BinOpSameOpcodeHelper.ll (+36) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/barriercall.ll (+1-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/bottom-to-top-reorder.ll (+3-8) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/buildvector-postpone-for-dependency.ll (+3-5) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/bv-shuffle-mask.ll (+1-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/extract-scalar-from-undef.ll (+12-16) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/extractcost.ll (+1-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/gathered-delayed-nodes-with-reused-user.ll (+16-18) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll (+1-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/multi-extracts-bv-combined.ll (+2-4) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/non-scheduled-inst-reused-as-last-inst.ll (+20-24) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/propagate_ir_flags.ll (+3-9) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/reduced-val-vectorized-in-transform.ll (+3-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll (+1-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/shuffle-mask-emission.ll (+3-5) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/vec3-base.ll (+12-7) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/vect_copyable_in_binops.ll (+2-6) 
- (modified) llvm/test/Transforms/SLPVectorizer/alternate-opcode-sindle-bv.ll (+23-12) 
- (added) llvm/test/Transforms/SLPVectorizer/bbi-106161.ll (+19) 
- (added) llvm/test/Transforms/SLPVectorizer/isOpcodeOrAlt.ll (+61) 
- (modified) llvm/test/Transforms/SLPVectorizer/resized-alt-shuffle-after-minbw.ll (+1-3) 
- (modified) llvm/test/Transforms/SLPVectorizer/shuffle-mask-resized.ll (+1-3) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index cc775e4b260dc..253933a2438cd 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -599,6 +599,28 @@ static std::optional<unsigned> getElementIndex(const Value *Inst,
   return Index;
 }
 
+/// \returns true if all of the values in \p VL use the same opcode, i.e. the
+/// opcode of the first instruction in \p VL. A CmpInst only matches if its
+/// predicate equals that of the first instruction (BAD_ICMP_PREDICATE when
+/// the first instruction is not a compare). PoisonValues are considered
+/// matching wherever they appear; any other non-instruction value is not.
+/// \returns true as well if \p VL contains no instruction at all.
+/// Interchangeable instructions are not considered.
+static bool allSameOpcode(ArrayRef<Value *> VL) {
+  auto *It = find_if(VL, IsaPred<Instruction>);
+  if (It == VL.end())
+    return true;
+  Instruction *MainOp = cast<Instruction>(*It);
+  unsigned Opcode = MainOp->getOpcode();
+  bool IsCmpOp = isa<CmpInst>(MainOp);
+  CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
+                                        : CmpInst::BAD_ICMP_PREDICATE;
+  // Check all of VL rather than starting at It: the values before It are
+  // non-instructions, and a non-poison one must be rejected regardless of
+  // whether it precedes or follows the first instruction.
+  return all_of(VL, [&](Value *V) {
+    if (auto *CI = dyn_cast<CmpInst>(V))
+      return BasePred == CI->getPredicate();
+    if (auto *I = dyn_cast<Instruction>(V))
+      return I->getOpcode() == Opcode;
+    return isa<PoisonValue>(V);
+  });
+}
+
 namespace {
 /// Specifies the way the mask should be analyzed for undefs/poisonous elements
 /// in the shuffle mask.
@@ -814,6 +836,272 @@ static std::optional<unsigned> getExtractIndex(const Instruction *E) {
 }
 
 namespace {
+/// \returns true if \p Opcode is allowed as part of the main/alternate
+/// instruction for SLP vectorization, i.e. it is not an integer division or
+/// remainder (Instruction::isIntDivRem).
+///
+/// For example, SDIV is rejected: the "shuffled out" lane is still computed
+/// and could divide by zero, causing UB.
+bool isValidForAlternation(unsigned Opcode) {
+  return !Instruction::isIntDivRem(Opcode);
+}
+
+/// Helper class that determines whether all instructions in VL can use the
+/// same opcode, allowing at most one alternate opcode. It also supports
+/// interchangeable instructions: instructions that can be rewritten with a
+/// different opcode and constant operand without changing their semantics.
+/// For example, x << 1 is equal to x * 2, and x * 1 is equal to x | 0.
+class BinOpSameOpcodeHelper {
+  using MaskType = std::uint_fast16_t;
+  /// Sort SupportedOp because it is used by binary_search.
+  constexpr static std::initializer_list<unsigned> SupportedOp = {
+      Instruction::Add,  Instruction::Sub, Instruction::Mul, Instruction::Shl,
+      Instruction::AShr, Instruction::And, Instruction::Or,  Instruction::Xor};
+  enum : MaskType {
+    ShlBIT = 0b1,
+    AShrBIT = 0b10,
+    MulBIT = 0b100,
+    AddBIT = 0b1000,
+    SubBIT = 0b10000,
+    AndBIT = 0b100000,
+    OrBIT = 0b1000000,
+    XorBIT = 0b10000000,
+    MainOpBIT = 0b100000000,
+    LLVM_MARK_AS_BITMASK_ENUM(MainOpBIT)
+  };
+  /// Return a non-nullptr if either operand of I is a ConstantInt.
+  /// The second return value represents the operand position. We check the
+  /// right-hand side first (1). If the right hand side is not a ConstantInt and
+  /// the instruction is neither Sub, Shl, nor AShr, we then check the left hand
+  /// side (0).
+  static std::pair<ConstantInt *, unsigned>
+  isBinOpWithConstantInt(const Instruction *I) {
+    unsigned Opcode = I->getOpcode();
+    assert(binary_search(SupportedOp, Opcode) && "Unsupported opcode.");
+    (void)SupportedOp;
+    auto *BinOp = cast<BinaryOperator>(I);
+    if (auto *CI = dyn_cast<ConstantInt>(BinOp->getOperand(1)))
+      return {CI, 1};
+    if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
+        Opcode == Instruction::AShr)
+      return {nullptr, 0};
+    if (auto *CI = dyn_cast<ConstantInt>(BinOp->getOperand(0)))
+      return {CI, 0};
+    return {nullptr, 0};
+  }
+  struct InterchangeableInfo {
+    const Instruction *I = nullptr;
+    /// Each set bit is an opcode that every instruction accepted so far can
+    /// be converted to; MainOpBIT stands for the opcode of I itself.
+    MaskType Mask = MainOpBIT | XorBIT | OrBIT | AndBIT | SubBIT | AddBIT |
+                    MulBIT | AShrBIT | ShlBIT;
+    /// We cannot create an interchangeable instruction that does not exist in
+    /// VL. For example, VL [x + 0, y * 1] can be converted to [x << 0, y << 0],
+    /// but << does not exist in VL. In the end, we convert VL to [x * 1, y *
+    /// 1]. SeenBefore is used to know what operations have been seen before.
+    MaskType SeenBefore = 0;
+    InterchangeableInfo(const Instruction *I) : I(I) {}
+    /// Return false allows BinOpSameOpcodeHelper to find an alternate
+    /// instruction. Directly setting the mask will destroy the mask state,
+    /// preventing us from determining which instruction it should convert to.
+    bool trySet(MaskType OpcodeInMaskForm, MaskType InterchangeableMask) {
+      if (Mask & InterchangeableMask) {
+        SeenBefore |= OpcodeInMaskForm;
+        Mask &= InterchangeableMask;
+        return true;
+      }
+      return false;
+    }
+    bool equal(unsigned Opcode) {
+      if (Opcode == I->getOpcode())
+        return trySet(MainOpBIT, MainOpBIT);
+      return false;
+    }
+    unsigned getOpcode() const {
+      MaskType Candidate = Mask & SeenBefore;
+      if (Candidate & MainOpBIT)
+        return I->getOpcode();
+      if (Candidate & ShlBIT)
+        return Instruction::Shl;
+      if (Candidate & AShrBIT)
+        return Instruction::AShr;
+      if (Candidate & MulBIT)
+        return Instruction::Mul;
+      if (Candidate & AddBIT)
+        return Instruction::Add;
+      if (Candidate & SubBIT)
+        return Instruction::Sub;
+      if (Candidate & AndBIT)
+        return Instruction::And;
+      if (Candidate & OrBIT)
+        return Instruction::Or;
+      if (Candidate & XorBIT)
+        return Instruction::Xor;
+      llvm_unreachable("Cannot find interchangeable instruction.");
+    }
+    /// \returns the operands of I rewritten for the opcode of \p To. If the
+    /// opcodes differ, I must be a supported BinaryOperator with a ConstantInt
+    /// operand (the asserts below mirror the masks computed in add()), e.g.
+    /// `x << 1` becomes {x, 2} for `mul` and `x + 0` becomes {x, -1} for `and`.
+    SmallVector<Value *> getOperand(const Instruction *To) const {
+      unsigned ToOpcode = To->getOpcode();
+      unsigned FromOpcode = I->getOpcode();
+      if (FromOpcode == ToOpcode)
+        return SmallVector<Value *>(I->operands());
+      assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
+      auto [CI, Pos] = isBinOpWithConstantInt(I);
+      const APInt &FromCIValue = CI->getValue();
+      unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
+      // Default to the identity constant of ToOpcode (x * 1, x & -1, and
+      // x op 0 otherwise). This is the right value whenever I itself is an
+      // identity operation; the cases below only override it when I carries
+      // real information (a shift amount, a power of 2 or an addend).
+      APInt ToCIValue = APInt::getZero(FromCIValueBitWidth);
+      if (ToOpcode == Instruction::Mul)
+        ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth, 0);
+      else if (ToOpcode == Instruction::And)
+        ToCIValue = APInt::getAllOnes(FromCIValueBitWidth);
+      switch (FromOpcode) {
+      case Instruction::Shl:
+        assert((ToOpcode == Instruction::Mul || FromCIValue.isZero()) &&
+               "Cannot convert the instruction.");
+        if (ToOpcode == Instruction::Mul)
+          ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
+                                          FromCIValue.getZExtValue());
+        break;
+      case Instruction::Mul:
+        assert(FromCIValue.isPowerOf2() && "Cannot convert the instruction.");
+        assert((ToOpcode == Instruction::Shl || FromCIValue.isOne()) &&
+               "Cannot convert the instruction.");
+        if (ToOpcode == Instruction::Shl)
+          ToCIValue = APInt(FromCIValueBitWidth, FromCIValue.logBase2());
+        break;
+      case Instruction::Add:
+      case Instruction::Sub:
+        if (!FromCIValue.isZero()) {
+          assert(is_contained({Instruction::Add, Instruction::Sub}, ToOpcode) &&
+                 "Cannot convert the instruction.");
+          ToCIValue = FromCIValue;
+          ToCIValue.negate();
+        }
+        break;
+      case Instruction::And:
+        assert(FromCIValue.isAllOnes() && "Cannot convert the instruction.");
+        break;
+      default:
+        assert(FromCIValue.isZero() && "Cannot convert the instruction.");
+        break;
+      }
+      Value *LHS = I->getOperand(1 - Pos);
+      Constant *RHS =
+          ConstantInt::get(I->getOperand(Pos)->getType(), ToCIValue);
+      // Sub, Shl and AShr are not commutative, so the new constant has to be
+      // their second operand even if I had it first (e.g. `0 + x` must become
+      // `x << 0`, not `0 << x`, and `C + x` must become `x - -C`).
+      if (Pos == 1 || !Instruction::isCommutative(ToOpcode))
+        return SmallVector<Value *>({LHS, RHS});
+      return SmallVector<Value *>({RHS, LHS});
+    }
+  };
+  InterchangeableInfo MainOp;
+  InterchangeableInfo AltOp;
+  bool isValidForAlternation(const Instruction *I) const {
+    return ::isValidForAlternation(MainOp.I->getOpcode()) &&
+           ::isValidForAlternation(I->getOpcode());
+  }
+  bool initializeAltOp(const Instruction *I) {
+    if (AltOp.I)
+      return true;
+    if (!isValidForAlternation(I))
+      return false;
+    AltOp.I = I;
+    return true;
+  }
+
+public:
+  BinOpSameOpcodeHelper(const Instruction *MainOp,
+                        const Instruction *AltOp = nullptr)
+      : MainOp(MainOp), AltOp(AltOp) {
+    assert(is_sorted(SupportedOp) && "SupportedOp is not sorted.");
+  }
+  bool add(const Instruction *I) {
+    assert(isa<BinaryOperator>(I) &&
+           "BinOpSameOpcodeHelper only accepts BinaryOperator.");
+    unsigned Opcode = I->getOpcode();
+    MaskType OpcodeInMaskForm;
+    // Supported opcodes are tracked through their own bit and may be
+    // interchanged; any other opcode must match MainOp or AltOp exactly.
+    switch (Opcode) {
+    case Instruction::Shl:
+      OpcodeInMaskForm = ShlBIT;
+      break;
+    case Instruction::AShr:
+      OpcodeInMaskForm = AShrBIT;
+      break;
+    case Instruction::Mul:
+      OpcodeInMaskForm = MulBIT;
+      break;
+    case Instruction::Add:
+      OpcodeInMaskForm = AddBIT;
+      break;
+    case Instruction::Sub:
+      OpcodeInMaskForm = SubBIT;
+      break;
+    case Instruction::And:
+      OpcodeInMaskForm = AndBIT;
+      break;
+    case Instruction::Or:
+      OpcodeInMaskForm = OrBIT;
+      break;
+    case Instruction::Xor:
+      OpcodeInMaskForm = XorBIT;
+      break;
+    default:
+      return MainOp.equal(Opcode) ||
+             (initializeAltOp(I) && AltOp.equal(Opcode));
+    }
+    MaskType InterchangeableMask = OpcodeInMaskForm;
+    ConstantInt *CI = isBinOpWithConstantInt(I).first;
+    if (CI) {
+      constexpr MaskType CanBeAll =
+          XorBIT | OrBIT | AndBIT | SubBIT | AddBIT | MulBIT | AShrBIT | ShlBIT;
+      const APInt &CIValue = CI->getValue();
+      switch (Opcode) {
+      case Instruction::Shl:
+        if (CIValue.ult(CIValue.getBitWidth()))
+          InterchangeableMask = CIValue.isZero() ? CanBeAll : MulBIT | ShlBIT;
+        break;
+      case Instruction::Mul:
+        if (CIValue.isOne()) {
+          InterchangeableMask = CanBeAll;
+          break;
+        }
+        if (CIValue.isPowerOf2())
+          InterchangeableMask = MulBIT | ShlBIT;
+        break;
+      case Instruction::Add:
+      case Instruction::Sub:
+        InterchangeableMask = CIValue.isZero() ? CanBeAll : SubBIT | AddBIT;
+        break;
+      case Instruction::And:
+        if (CIValue.isAllOnes())
+          InterchangeableMask = CanBeAll;
+        break;
+      default:
+        if (CIValue.isZero())
+          InterchangeableMask = CanBeAll;
+        break;
+      }
+    }
+    return MainOp.trySet(OpcodeInMaskForm, InterchangeableMask) ||
+           (initializeAltOp(I) &&
+            AltOp.trySet(OpcodeInMaskForm, InterchangeableMask));
+  }
+  unsigned getMainOpcode() const { return MainOp.getOpcode(); }
+  bool hasAltOp() const { return AltOp.I; }
+  unsigned getAltOpcode() const {
+    return hasAltOp() ? AltOp.getOpcode() : getMainOpcode();
+  }
+  SmallVector<Value *> getOperand(const Instruction *I) const {
+    return MainOp.getOperand(I);
+  }
+};
 
 /// Main data required for vectorization of instructions.
 class InstructionsState {
@@ -861,9 +1149,27 @@ class InstructionsState {
   /// Some of the instructions in the list have alternate opcodes.
   bool isAltShuffle() const { return getMainOp() != getAltOp(); }
 
-  bool isOpcodeOrAlt(Instruction *I) const {
-    unsigned CheckedOpcode = I->getOpcode();
-    return getOpcode() == CheckedOpcode || getAltOpcode() == CheckedOpcode;
+  /// Checks if the instruction matches either the main or alternate opcode.
+  /// \returns
+  /// - MainOp if \p I has MainOp's opcode, or is a BinaryOperator that a
+  /// BinOpSameOpcodeHelper seeded with MainOp accepts together with MainOp
+  /// without needing an alternate opcode
+  /// - AltOp if \p I has AltOp's opcode, or is any other BinaryOperator (this
+  /// fallback does not verify that \p I is convertible to AltOp; presumably
+  /// callers only pass lanes already accepted by getSameOpcode - TODO confirm)
+  /// - nullptr if \p I is not a BinaryOperator and matches neither opcode
+  Instruction *getMatchingMainOpOrAltOp(Instruction *I) const {
+    assert(MainOp && "MainOp cannot be nullptr.");
+    if (I->getOpcode() == MainOp->getOpcode())
+      return MainOp;
+    // Prefer AltOp instead of interchangeable instruction of MainOp.
+    assert(AltOp && "AltOp cannot be nullptr.");
+    if (I->getOpcode() == AltOp->getOpcode())
+      return AltOp;
+    if (!I->isBinaryOp())
+      return nullptr;
+    // Feed I and MainOp to a fresh helper; if both fit without it having to
+    // open an alternate opcode, treat I as a MainOp lane.
+    BinOpSameOpcodeHelper Converter(MainOp);
+    if (Converter.add(I) && Converter.add(MainOp) && !Converter.hasAltOp())
+      return MainOp;
+    return AltOp;
+  }
 
   /// Checks if main/alt instructions are shift operations.
@@ -913,23 +1219,41 @@ class InstructionsState {
   static InstructionsState invalid() { return {nullptr, nullptr}; }
 };
 
-} // end anonymous namespace
-
-/// \returns true if \p Opcode is allowed as part of the main/alternate
-/// instruction for SLP vectorization.
-///
-/// Example of unsupported opcode is SDIV that can potentially cause UB if the
-/// "shuffled out" lane would result in division by zero.
-static bool isValidForAlternation(unsigned Opcode) {
-  if (Instruction::isIntDivRem(Opcode))
-    return false;
-
-  return true;
+std::pair<Instruction *, SmallVector<Value *>>
+convertTo(Instruction *I, const InstructionsState &S) {
+  Instruction *SelectedOp = S.getMatchingMainOpOrAltOp(I);
+  assert(SelectedOp && "Cannot convert the instruction.");
+  if (I->isBinaryOp()) {
+    BinOpSameOpcodeHelper Converter(I);
+    return std::make_pair(SelectedOp, Converter.getOperand(SelectedOp));
+  }
+  return std::make_pair(SelectedOp, SmallVector<Value *>(I->operands()));
 }
 
+} // end anonymous namespace
+
 static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                                        const TargetLibraryInfo &TLI);
 
+/// Find an instruction with a specific opcode in VL.
+/// This file-local helper is declared static because it is defined outside
+/// the anonymous namespace and would otherwise get external linkage.
+/// \param VL Array of values to search through. Must contain only Instructions
+///           and PoisonValues.
+/// \param Opcode The instruction opcode to search for
+/// \returns
+/// - The first instruction found with matching opcode
+/// - nullptr if no matching instruction is found
+static Instruction *findInstructionWithOpcode(ArrayRef<Value *> VL,
+                                              unsigned Opcode) {
+  for (Value *V : VL) {
+    if (isa<PoisonValue>(V))
+      continue;
+    assert(isa<Instruction>(V) && "Only accepts PoisonValue and Instruction.");
+    auto *Inst = cast<Instruction>(V);
+    if (Inst->getOpcode() == Opcode)
+      return Inst;
+  }
+  return nullptr;
+}
+
 /// Checks if the provided operands of 2 cmp instructions are compatible, i.e.
 /// compatible instructions or constants, or just some other regular values.
 static bool areCompatibleCmpOps(Value *BaseOp0, Value *BaseOp1, Value *Op0,
@@ -993,6 +1317,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   unsigned Opcode = MainOp->getOpcode();
   unsigned AltOpcode = Opcode;
 
+  BinOpSameOpcodeHelper BinOpHelper(MainOp);
   bool SwappedPredsCompatible = IsCmpOp && [&]() {
     SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
     UniquePreds.insert(BasePred);
@@ -1039,14 +1364,8 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
       return InstructionsState::invalid();
     unsigned InstOpcode = I->getOpcode();
     if (IsBinOp && isa<BinaryOperator>(I)) {
-      if (InstOpcode == Opcode || InstOpcode == AltOpcode)
+      if (BinOpHelper.add(I))
         continue;
-      if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
-          isValidForAlternation(Opcode)) {
-        AltOpcode = InstOpcode;
-        AltOp = I;
-        continue;
-      }
     } else if (IsCastOp && isa<CastInst>(I)) {
       Value *Op0 = MainOp->getOperand(0);
       Type *Ty0 = Op0->getType();
@@ -1147,7 +1466,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
     return InstructionsState::invalid();
   }
 
-  return InstructionsState(MainOp, AltOp);
+  if (IsBinOp) {
+    MainOp = findInstructionWithOpcode(VL, BinOpHelper.getMainOpcode());
+    assert(MainOp && "Cannot find MainOp with Opcode from BinOpHelper.");
+    AltOp = findInstructionWithOpcode(VL, BinOpHelper.getAltOpcode());
+    assert(MainOp && "Cannot find AltOp with Opcode from BinOpHelper.");
+  }
+  assert((MainOp == AltOp || !allSameOpcode(VL)) &&
+         "Incorrect implementation of allSameOpcode.");
+  InstructionsState S(MainOp, AltOp);
+  assert(all_of(VL,
+                [&](Value *V) {
+                  return isa<PoisonValue>(V) ||
+                         S.getMatchingMainOpOrAltOp(cast<Instruction>(V));
+                }) &&
+         "Invalid InstructionsState.");
+  return S;
 }
 
 /// \returns true if all of the values in \p VL have the same type or false
@@ -2560,11 +2894,11 @@ class BoUpSLP {
         // Since operand reordering is performed on groups of commutative
         // operations or alternating sequences (e.g., +, -), we can safely tell
         // the inverse operations by checking commutativity.
-        bool IsInverseOperation = !isCommutative(cast<Instruction>(V));
+        auto [SelectedOp, Ops] = convertTo(cast<Instruction>(VL[Lane]), S);
+        bool IsInverseOperation = !isCommutative(SelectedOp);
         for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) {
           bool APO = (OpIdx == 0) ? false : IsInverseOperation;
-          OpsVec[OpIdx][Lane] = {cast<Instruction>(V)->getOperand(OpIdx), APO,
-                                 false};
+          OpsVec[OpIdx][Lane] = {Ops[OpIdx], APO, false};
         }
       }
     }
@@ -3542,14 +3876,16 @@ class BoUpSLP {
     /// Some of the instructions in the list have alternate opcodes.
     bool isAltShuffle() const { return S.isAltShuffle(); }
 
-    bool isOpcodeOrAlt(Instruction *I) const { return S.isOpcodeOrAlt(I); }
+    Instruction *getMatchingMainOpOrAltOp(Instruction *I) const {
+      return S.getMatchingMainOpOrAltOp(I);
+    }
 
     /// Chooses the correct key for scheduling data. If \p Op has the same (or
     /// alternate) opcode as \p OpValue, the key is \p Op. Otherwise the key is
     /// \p OpValue.
     Value *isOneOf(Value *Op) const {
       auto *I = dyn_cast<Instruction>(Op);
-      if (I && isOpcodeOrAlt(I))
+      if (I && getMatchingMainOpOrAltOp(I))
         return Op;
       return S.getMainOp();
     }
@@ -8428,11 +8764,15 @@ static std::pair<size_t, size_t> generateKeySubkey(
   return std::make_pair(Key, SubKey);
 }
 
+/// Checks if the specified instruction \p I is an main operation for the given
+/// \p MainOp and \p AltOp instructions.
+static bool isMainInstruction(Instruction *I, Instruction *MainOp,
+                              Instruction *AltOp, const TargetLibraryInfo &TLI);
+
 /// Checks if the specified instruction \p I is an alternate operation for
 /// the given \p MainOp and \p AltOp instructions.
-static bool isAlternateInstruction(const Instruction *I,
-                                   const Instruction *MainOp,
-                                   const Instruction *AltOp,
+static bool isAlternateInstruction(Instruction *I, Instruction *MainOp,
+                                   Instruction *AltOp,
                                    const TargetLibraryInfo &TLI);
 
 bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
@@ -9245,7 +9585,8 @@ bool BoUpSLP::canBuildSplitNode(ArrayRef<Value *> VL,
       continue;
     }
     if ((LocalState.getAltOpcode() != LocalState.getOpcode() &&
-         I->getOpcode() == LocalState.getOpcode()) ||
+         isMainInstruction(I, LocalState.getMainOp(), LocalState.getAltOp(),
+                           *TLI)) ||
         (LocalState.getAltOpcode() == LocalState.getOpcode() &&
          !isAlternateInstruction(I, Local...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/135797


More information about the llvm-commits mailing list