[llvm] 853c43d - [TTI] NFC: Port TLI.shouldSinkOperands to TTI (#110564)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 14:30:14 PDT 2024


Author: Jeffrey Byrnes
Date: 2024-10-09T14:30:09-07:00
New Revision: 853c43d04a378c379e49db552e856f02a5ad9216

URL: https://github.com/llvm/llvm-project/commit/853c43d04a378c379e49db552e856f02a5ad9216
DIFF: https://github.com/llvm/llvm-project/commit/853c43d04a378c379e49db552e856f02a5ad9216.diff

LOG: [TTI] NFC: Port TLI.shouldSinkOperands to TTI (#110564)

Porting to TTI provides direct access to the instruction cost model,
which can enable instruction-cost-based sinking without introducing code
duplication.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/lib/Target/ARM/ARMISelLowering.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
    llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 2befacea4df866..5c5da5e06c1bff 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1747,6 +1747,21 @@ class TargetTransformInfo {
   bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                              Align Alignment) const;
 
+  /// Return true if sinking I's operands to the same basic block as I is
+  /// profitable, e.g. because the operands can be folded into a target
+  /// instruction during instruction selection. After calling the function
+  /// \p Ops contains the Uses to sink ordered by dominance (dominating users
+  /// come first).
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
+
+  /// Return true if it's significantly cheaper to shift a vector by a uniform
+  /// scalar than by an amount which will vary across each lane. On x86 before
+  /// AVX2 for example, there is a "psllw" instruction for the former case, but
+  /// no simple instruction for a general "a << b" operation on vectors.
+  /// This should also apply to lowering for vector funnel shifts (rotates).
+  bool isVectorShiftByScalarCheap(Type *Ty) const;
+
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
@@ -2187,6 +2202,11 @@ class TargetTransformInfo::Concept {
   virtual bool supportsScalableVectors() const = 0;
   virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                                      Align Alignment) const = 0;
+  virtual bool
+  isProfitableToSinkOperands(Instruction *I,
+                             SmallVectorImpl<Use *> &OpsToSink) const = 0;
+
+  virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -2963,6 +2983,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
   }
 
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const override {
+    return Impl.isProfitableToSinkOperands(I, Ops);
+  };
+
+  bool isVectorShiftByScalarCheap(Type *Ty) const override {
+    return Impl.isVectorShiftByScalarCheap(Ty);
+  }
+
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01a16e7c7b1e59..6d3ce93acbe451 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -977,6 +977,13 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const {
+    return false;
+  }
+
+  bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
+
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(

diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 4c76592c42e1eb..5ab31a687ec5e9 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2860,15 +2860,6 @@ class TargetLoweringBase {
     return Value == 0;
   }
 
-  /// Return true if it's significantly cheaper to shift a vector by a uniform
-  /// scalar than by an amount which will vary across each lane. On x86 before
-  /// AVX2 for example, there is a "psllw" instruction for the former case, but
-  /// no simple instruction for a general "a << b" operation on vectors.
-  /// This should also apply to lowering for vector funnel shifts (rotates).
-  virtual bool isVectorShiftByScalarCheap(Type *Ty) const {
-    return false;
-  }
-
   /// Given a shuffle vector SVI representing a vector splat, return a new
   /// scalar type of size equal to SVI's scalar type if the new type is more
   /// profitable. Returns nullptr otherwise. For example under MVE float splats
@@ -3085,16 +3076,6 @@ class TargetLoweringBase {
   /// a larger type.
   virtual bool signExtendConstant(const ConstantInt *C) const { return false; }
 
-  /// Return true if sinking I's operands to the same basic block as I is
-  /// profitable, e.g. because the operands can be folded into a target
-  /// instruction during instruction selection. After calling the function
-  /// \p Ops contains the Uses to sink ordered by dominance (dominating users
-  /// come first).
-  virtual bool shouldSinkOperands(Instruction *I,
-                                  SmallVectorImpl<Use *> &Ops) const {
-    return false;
-  }
-
   /// Try to optimize extending or truncating conversion instructions (like
   /// zext, trunc, fptoui, uitofp) for the target.
   virtual bool

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b612a3331e5737..3dc29fc7cd77b1 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1359,6 +1359,15 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::isProfitableToSinkOperands(
+    Instruction *I, SmallVectorImpl<Use *> &OpsToSink) const {
+  return TTIImpl->isProfitableToSinkOperands(I, OpsToSink);
+}
+
+bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
+  return TTIImpl->isVectorShiftByScalarCheap(Ty);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}

diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 631cc26d6022fe..3e09fbad6ab198 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -7274,7 +7274,7 @@ bool CodeGenPrepare::optimizeShiftInst(BinaryOperator *Shift) {
   // We can't do this effectively in SDAG because we may not be able to
   // determine if the select operands are splats from within a basic block.
   Type *Ty = Shift->getType();
-  if (!Ty->isVectorTy() || !TLI->isVectorShiftByScalarCheap(Ty))
+  if (!Ty->isVectorTy() || !TTI->isVectorShiftByScalarCheap(Ty))
     return false;
   Value *Cond, *TVal, *FVal;
   if (!match(Shift->getOperand(1),
@@ -7309,7 +7309,7 @@ bool CodeGenPrepare::optimizeFunnelShift(IntrinsicInst *Fsh) {
   // We can't do this effectively in SDAG because we may not be able to
   // determine if the select operands are splats from within a basic block.
   Type *Ty = Fsh->getType();
-  if (!Ty->isVectorTy() || !TLI->isVectorShiftByScalarCheap(Ty))
+  if (!Ty->isVectorTy() || !TTI->isVectorShiftByScalarCheap(Ty))
     return false;
   Value *Cond, *TVal, *FVal;
   if (!match(Fsh->getOperand(2),
@@ -7566,7 +7566,7 @@ bool CodeGenPrepare::tryToSinkFreeOperands(Instruction *I) {
   // If the operands of I can be folded into a target instruction together with
   // I, duplicate and sink them.
   SmallVector<Use *, 4> OpsToSink;
-  if (!TLI->shouldSinkOperands(I, OpsToSink))
+  if (!TTI->isProfitableToSinkOperands(I, OpsToSink))
     return false;
 
   // OpsToSink can contain multiple uses in a use chain (e.g.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 288fd3639e5eb7..381794caeb85be 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16326,422 +16326,6 @@ bool AArch64TargetLowering::isExtFreeImpl(const Instruction *Ext) const {
   return true;
 }
 
-static bool isSplatShuffle(Value *V) {
-  if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
-    return all_equal(Shuf->getShuffleMask());
-  return false;
-}
-
-/// Check if both Op1 and Op2 are shufflevector extracts of either the lower
-/// or upper half of the vector elements.
-static bool areExtractShuffleVectors(Value *Op1, Value *Op2,
-                                     bool AllowSplat = false) {
-  // Scalable types can't be extract shuffle vectors.
-  if (Op1->getType()->isScalableTy() || Op2->getType()->isScalableTy())
-    return false;
-
-  auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
-    auto *FullTy = FullV->getType();
-    auto *HalfTy = HalfV->getType();
-    return FullTy->getPrimitiveSizeInBits().getFixedValue() ==
-           2 * HalfTy->getPrimitiveSizeInBits().getFixedValue();
-  };
-
-  auto extractHalf = [](Value *FullV, Value *HalfV) {
-    auto *FullVT = cast<FixedVectorType>(FullV->getType());
-    auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
-    return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
-  };
-
-  ArrayRef<int> M1, M2;
-  Value *S1Op1 = nullptr, *S2Op1 = nullptr;
-  if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
-      !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
-    return false;
-
-  // If we allow splats, set S1Op1/S2Op1 to nullptr for the relavant arg so that
-  // it is not checked as an extract below.
-  if (AllowSplat && isSplatShuffle(Op1))
-    S1Op1 = nullptr;
-  if (AllowSplat && isSplatShuffle(Op2))
-    S2Op1 = nullptr;
-
-  // Check that the operands are half as wide as the result and we extract
-  // half of the elements of the input vectors.
-  if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) ||
-      (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2))))
-    return false;
-
-  // Check the mask extracts either the lower or upper half of vector
-  // elements.
-  int M1Start = 0;
-  int M2Start = 0;
-  int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
-  if ((S1Op1 &&
-       !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) ||
-      (S2Op1 &&
-       !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start)))
-    return false;
-
-  if ((M1Start != 0 && M1Start != (NumElements / 2)) ||
-      (M2Start != 0 && M2Start != (NumElements / 2)))
-    return false;
-  if (S1Op1 && S2Op1 && M1Start != M2Start)
-    return false;
-
-  return true;
-}
-
-/// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
-/// of the vector elements.
-static bool areExtractExts(Value *Ext1, Value *Ext2) {
-  auto areExtDoubled = [](Instruction *Ext) {
-    return Ext->getType()->getScalarSizeInBits() ==
-           2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
-  };
-
-  if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
-      !match(Ext2, m_ZExtOrSExt(m_Value())) ||
-      !areExtDoubled(cast<Instruction>(Ext1)) ||
-      !areExtDoubled(cast<Instruction>(Ext2)))
-    return false;
-
-  return true;
-}
-
-/// Check if Op could be used with vmull_high_p64 intrinsic.
-static bool isOperandOfVmullHighP64(Value *Op) {
-  Value *VectorOperand = nullptr;
-  ConstantInt *ElementIndex = nullptr;
-  return match(Op, m_ExtractElt(m_Value(VectorOperand),
-                                m_ConstantInt(ElementIndex))) &&
-         ElementIndex->getValue() == 1 &&
-         isa<FixedVectorType>(VectorOperand->getType()) &&
-         cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
-}
-
-/// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
-static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
-  return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
-}
-
-static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
-  // Restrict ourselves to the form CodeGenPrepare typically constructs.
-  auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
-  if (!GEP || GEP->getNumOperands() != 2)
-    return false;
-
-  Value *Base = GEP->getOperand(0);
-  Value *Offsets = GEP->getOperand(1);
-
-  // We only care about scalar_base+vector_offsets.
-  if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
-    return false;
-
-  // Sink extends that would allow us to use 32-bit offset vectors.
-  if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
-    auto *OffsetsInst = cast<Instruction>(Offsets);
-    if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
-        OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
-      Ops.push_back(&GEP->getOperandUse(1));
-  }
-
-  // Sink the GEP.
-  return true;
-}
-
-/// We want to sink following cases:
-/// (add|sub|gep) A, ((mul|shl) vscale, imm); (add|sub|gep) A, vscale;
-/// (add|sub|gep) A, ((mul|shl) zext(vscale), imm);
-static bool shouldSinkVScale(Value *Op, SmallVectorImpl<Use *> &Ops) {
-  if (match(Op, m_VScale()))
-    return true;
-  if (match(Op, m_Shl(m_VScale(), m_ConstantInt())) ||
-      match(Op, m_Mul(m_VScale(), m_ConstantInt()))) {
-    Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
-    return true;
-  }
-  if (match(Op, m_Shl(m_ZExt(m_VScale()), m_ConstantInt())) ||
-      match(Op, m_Mul(m_ZExt(m_VScale()), m_ConstantInt()))) {
-    Value *ZExtOp = cast<Instruction>(Op)->getOperand(0);
-    Ops.push_back(&cast<Instruction>(ZExtOp)->getOperandUse(0));
-    Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
-    return true;
-  }
-  return false;
-}
-
-/// Check if sinking \p I's operands to I's basic block is profitable, because
-/// the operands can be folded into a target instruction, e.g.
-/// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2).
-bool AArch64TargetLowering::shouldSinkOperands(
-    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
-  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
-    switch (II->getIntrinsicID()) {
-    case Intrinsic::aarch64_neon_smull:
-    case Intrinsic::aarch64_neon_umull:
-      if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1),
-                                   /*AllowSplat=*/true)) {
-        Ops.push_back(&II->getOperandUse(0));
-        Ops.push_back(&II->getOperandUse(1));
-        return true;
-      }
-      [[fallthrough]];
-
-    case Intrinsic::fma:
-    case Intrinsic::fmuladd:
-      if (isa<VectorType>(I->getType()) &&
-          cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
-          !Subtarget->hasFullFP16())
-        return false;
-      [[fallthrough]];
-    case Intrinsic::aarch64_neon_sqdmull:
-    case Intrinsic::aarch64_neon_sqdmulh:
-    case Intrinsic::aarch64_neon_sqrdmulh:
-      // Sink splats for index lane variants
-      if (isSplatShuffle(II->getOperand(0)))
-        Ops.push_back(&II->getOperandUse(0));
-      if (isSplatShuffle(II->getOperand(1)))
-        Ops.push_back(&II->getOperandUse(1));
-      return !Ops.empty();
-    case Intrinsic::aarch64_neon_fmlal:
-    case Intrinsic::aarch64_neon_fmlal2:
-    case Intrinsic::aarch64_neon_fmlsl:
-    case Intrinsic::aarch64_neon_fmlsl2:
-      // Sink splats for index lane variants
-      if (isSplatShuffle(II->getOperand(1)))
-        Ops.push_back(&II->getOperandUse(1));
-      if (isSplatShuffle(II->getOperand(2)))
-        Ops.push_back(&II->getOperandUse(2));
-      return !Ops.empty();
-    case Intrinsic::aarch64_sve_ptest_first:
-    case Intrinsic::aarch64_sve_ptest_last:
-      if (auto *IIOp = dyn_cast<IntrinsicInst>(II->getOperand(0)))
-        if (IIOp->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue)
-          Ops.push_back(&II->getOperandUse(0));
-      return !Ops.empty();
-    case Intrinsic::aarch64_sme_write_horiz:
-    case Intrinsic::aarch64_sme_write_vert:
-    case Intrinsic::aarch64_sme_writeq_horiz:
-    case Intrinsic::aarch64_sme_writeq_vert: {
-      auto *Idx = dyn_cast<Instruction>(II->getOperand(1));
-      if (!Idx || Idx->getOpcode() != Instruction::Add)
-        return false;
-      Ops.push_back(&II->getOperandUse(1));
-      return true;
-    }
-    case Intrinsic::aarch64_sme_read_horiz:
-    case Intrinsic::aarch64_sme_read_vert:
-    case Intrinsic::aarch64_sme_readq_horiz:
-    case Intrinsic::aarch64_sme_readq_vert:
-    case Intrinsic::aarch64_sme_ld1b_vert:
-    case Intrinsic::aarch64_sme_ld1h_vert:
-    case Intrinsic::aarch64_sme_ld1w_vert:
-    case Intrinsic::aarch64_sme_ld1d_vert:
-    case Intrinsic::aarch64_sme_ld1q_vert:
-    case Intrinsic::aarch64_sme_st1b_vert:
-    case Intrinsic::aarch64_sme_st1h_vert:
-    case Intrinsic::aarch64_sme_st1w_vert:
-    case Intrinsic::aarch64_sme_st1d_vert:
-    case Intrinsic::aarch64_sme_st1q_vert:
-    case Intrinsic::aarch64_sme_ld1b_horiz:
-    case Intrinsic::aarch64_sme_ld1h_horiz:
-    case Intrinsic::aarch64_sme_ld1w_horiz:
-    case Intrinsic::aarch64_sme_ld1d_horiz:
-    case Intrinsic::aarch64_sme_ld1q_horiz:
-    case Intrinsic::aarch64_sme_st1b_horiz:
-    case Intrinsic::aarch64_sme_st1h_horiz:
-    case Intrinsic::aarch64_sme_st1w_horiz:
-    case Intrinsic::aarch64_sme_st1d_horiz:
-    case Intrinsic::aarch64_sme_st1q_horiz: {
-      auto *Idx = dyn_cast<Instruction>(II->getOperand(3));
-      if (!Idx || Idx->getOpcode() != Instruction::Add)
-        return false;
-      Ops.push_back(&II->getOperandUse(3));
-      return true;
-    }
-    case Intrinsic::aarch64_neon_pmull:
-      if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
-        return false;
-      Ops.push_back(&II->getOperandUse(0));
-      Ops.push_back(&II->getOperandUse(1));
-      return true;
-    case Intrinsic::aarch64_neon_pmull64:
-      if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
-                                     II->getArgOperand(1)))
-        return false;
-      Ops.push_back(&II->getArgOperandUse(0));
-      Ops.push_back(&II->getArgOperandUse(1));
-      return true;
-    case Intrinsic::masked_gather:
-      if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
-        return false;
-      Ops.push_back(&II->getArgOperandUse(0));
-      return true;
-    case Intrinsic::masked_scatter:
-      if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
-        return false;
-      Ops.push_back(&II->getArgOperandUse(1));
-      return true;
-    default:
-      return false;
-    }
-  }
-
-  // Sink vscales closer to uses for better isel
-  switch (I->getOpcode()) {
-  case Instruction::GetElementPtr:
-  case Instruction::Add:
-  case Instruction::Sub:
-    for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) {
-      if (shouldSinkVScale(I->getOperand(Op), Ops)) {
-        Ops.push_back(&I->getOperandUse(Op));
-        return true;
-      }
-    }
-    break;
-  default:
-    break;
-  }
-
-  if (!I->getType()->isVectorTy())
-    return false;
-
-  switch (I->getOpcode()) {
-  case Instruction::Sub:
-  case Instruction::Add: {
-    if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
-      return false;
-
-    // If the exts' operands extract either the lower or upper elements, we
-    // can sink them too.
-    auto Ext1 = cast<Instruction>(I->getOperand(0));
-    auto Ext2 = cast<Instruction>(I->getOperand(1));
-    if (areExtractShuffleVectors(Ext1->getOperand(0), Ext2->getOperand(0))) {
-      Ops.push_back(&Ext1->getOperandUse(0));
-      Ops.push_back(&Ext2->getOperandUse(0));
-    }
-
-    Ops.push_back(&I->getOperandUse(0));
-    Ops.push_back(&I->getOperandUse(1));
-
-    return true;
-  }
-  case Instruction::Or: {
-    // Pattern: Or(And(MaskValue, A), And(Not(MaskValue), B)) ->
-    // bitselect(MaskValue, A, B) where Not(MaskValue) = Xor(MaskValue, -1)
-    if (Subtarget->hasNEON()) {
-      Instruction *OtherAnd, *IA, *IB;
-      Value *MaskValue;
-      // MainAnd refers to And instruction that has 'Not' as one of its operands
-      if (match(I, m_c_Or(m_OneUse(m_Instruction(OtherAnd)),
-                          m_OneUse(m_c_And(m_OneUse(m_Not(m_Value(MaskValue))),
-                                           m_Instruction(IA)))))) {
-        if (match(OtherAnd,
-                  m_c_And(m_Specific(MaskValue), m_Instruction(IB)))) {
-          Instruction *MainAnd = I->getOperand(0) == OtherAnd
-                                     ? cast<Instruction>(I->getOperand(1))
-                                     : cast<Instruction>(I->getOperand(0));
-
-          // Both Ands should be in same basic block as Or
-          if (I->getParent() != MainAnd->getParent() ||
-              I->getParent() != OtherAnd->getParent())
-            return false;
-
-          // Non-mask operands of both Ands should also be in same basic block
-          if (I->getParent() != IA->getParent() ||
-              I->getParent() != IB->getParent())
-            return false;
-
-          Ops.push_back(&MainAnd->getOperandUse(MainAnd->getOperand(0) == IA ? 1 : 0));
-          Ops.push_back(&I->getOperandUse(0));
-          Ops.push_back(&I->getOperandUse(1));
-
-          return true;
-        }
-      }
-    }
-
-    return false;
-  }
-  case Instruction::Mul: {
-    int NumZExts = 0, NumSExts = 0;
-    for (auto &Op : I->operands()) {
-      // Make sure we are not already sinking this operand
-      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
-        continue;
-
-      if (match(&Op, m_SExt(m_Value()))) {
-        NumSExts++;
-        continue;
-      } else if (match(&Op, m_ZExt(m_Value()))) {
-        NumZExts++;
-        continue;
-      }
-
-      ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
-
-      // If the Shuffle is a splat and the operand is a zext/sext, sinking the
-      // operand and the s/zext can help create indexed s/umull. This is
-      // especially useful to prevent i64 mul being scalarized.
-      if (Shuffle && isSplatShuffle(Shuffle) &&
-          match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
-        Ops.push_back(&Shuffle->getOperandUse(0));
-        Ops.push_back(&Op);
-        if (match(Shuffle->getOperand(0), m_SExt(m_Value())))
-          NumSExts++;
-        else
-          NumZExts++;
-        continue;
-      }
-
-      if (!Shuffle)
-        continue;
-
-      Value *ShuffleOperand = Shuffle->getOperand(0);
-      InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
-      if (!Insert)
-        continue;
-
-      Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
-      if (!OperandInstr)
-        continue;
-
-      ConstantInt *ElementConstant =
-          dyn_cast<ConstantInt>(Insert->getOperand(2));
-      // Check that the insertelement is inserting into element 0
-      if (!ElementConstant || !ElementConstant->isZero())
-        continue;
-
-      unsigned Opcode = OperandInstr->getOpcode();
-      if (Opcode == Instruction::SExt)
-        NumSExts++;
-      else if (Opcode == Instruction::ZExt)
-        NumZExts++;
-      else {
-        // If we find that the top bits are known 0, then we can sink and allow
-        // the backend to generate a umull.
-        unsigned Bitwidth = I->getType()->getScalarSizeInBits();
-        APInt UpperMask = APInt::getHighBitsSet(Bitwidth, Bitwidth / 2);
-        const DataLayout &DL = I->getDataLayout();
-        if (!MaskedValueIsZero(OperandInstr, UpperMask, DL))
-          continue;
-        NumZExts++;
-      }
-
-      Ops.push_back(&Shuffle->getOperandUse(0));
-      Ops.push_back(&Op);
-    }
-
-    // Is it profitable to sink if we found two of the same type of extends.
-    return !Ops.empty() && (NumSExts == 2 || NumZExts == 2);
-  }
-  default:
-    return false;
-  }
-  return false;
-}
-
 static bool createTblShuffleMask(unsigned SrcWidth, unsigned DstWidth,
                                  unsigned NumElts, bool IsLittleEndian,
                                  SmallVectorImpl<int> &Mask) {

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1bae7562f459a5..035a802cd49b3c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -688,9 +688,6 @@ class AArch64TargetLowering : public TargetLowering {
   bool isZExtFree(EVT VT1, EVT VT2) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
 
-  bool shouldSinkOperands(Instruction *I,
-                          SmallVectorImpl<Use *> &Ops) const override;
-
   bool optimizeExtendOrTruncateConversion(
       Instruction *I, Loop *L, const TargetTransformInfo &TTI) const override;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 80d5168ae961ab..7b74bb2a03a642 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4672,3 +4672,420 @@ bool AArch64TTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
 
   return TargetTransformInfoImplBase::isLSRCostLess(C1, C2);
 }
+
+static bool isSplatShuffle(Value *V) {
+  if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
+    return all_equal(Shuf->getShuffleMask());
+  return false;
+}
+
+/// Check if both Op1 and Op2 are shufflevector extracts of either the lower
+/// or upper half of the vector elements.
+static bool areExtractShuffleVectors(Value *Op1, Value *Op2,
+                                     bool AllowSplat = false) {
+  // Scalable types can't be extract shuffle vectors.
+  if (Op1->getType()->isScalableTy() || Op2->getType()->isScalableTy())
+    return false;
+
+  auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
+    auto *FullTy = FullV->getType();
+    auto *HalfTy = HalfV->getType();
+    return FullTy->getPrimitiveSizeInBits().getFixedValue() ==
+           2 * HalfTy->getPrimitiveSizeInBits().getFixedValue();
+  };
+
+  auto extractHalf = [](Value *FullV, Value *HalfV) {
+    auto *FullVT = cast<FixedVectorType>(FullV->getType());
+    auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
+    return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
+  };
+
+  ArrayRef<int> M1, M2;
+  Value *S1Op1 = nullptr, *S2Op1 = nullptr;
+  if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
+      !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
+    return false;
+
+  // If we allow splats, set S1Op1/S2Op1 to nullptr for the relevant arg so that
+  // it is not checked as an extract below.
+  if (AllowSplat && isSplatShuffle(Op1))
+    S1Op1 = nullptr;
+  if (AllowSplat && isSplatShuffle(Op2))
+    S2Op1 = nullptr;
+
+  // Check that the operands are half as wide as the result and we extract
+  // half of the elements of the input vectors.
+  if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) ||
+      (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2))))
+    return false;
+
+  // Check the mask extracts either the lower or upper half of vector
+  // elements.
+  int M1Start = 0;
+  int M2Start = 0;
+  int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
+  if ((S1Op1 &&
+       !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) ||
+      (S2Op1 &&
+       !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start)))
+    return false;
+
+  if ((M1Start != 0 && M1Start != (NumElements / 2)) ||
+      (M2Start != 0 && M2Start != (NumElements / 2)))
+    return false;
+  if (S1Op1 && S2Op1 && M1Start != M2Start)
+    return false;
+
+  return true;
+}
+
+/// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
+/// of the vector elements.
+static bool areExtractExts(Value *Ext1, Value *Ext2) {
+  auto areExtDoubled = [](Instruction *Ext) {
+    return Ext->getType()->getScalarSizeInBits() ==
+           2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
+  };
+
+  if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
+      !match(Ext2, m_ZExtOrSExt(m_Value())) ||
+      !areExtDoubled(cast<Instruction>(Ext1)) ||
+      !areExtDoubled(cast<Instruction>(Ext2)))
+    return false;
+
+  return true;
+}
+
+/// Check if Op could be used with vmull_high_p64 intrinsic.
+static bool isOperandOfVmullHighP64(Value *Op) {
+  Value *VectorOperand = nullptr;
+  ConstantInt *ElementIndex = nullptr;
+  return match(Op, m_ExtractElt(m_Value(VectorOperand),
+                                m_ConstantInt(ElementIndex))) &&
+         ElementIndex->getValue() == 1 &&
+         isa<FixedVectorType>(VectorOperand->getType()) &&
+         cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
+}
+
+/// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
+static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
+  return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
+}
+
+static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
+  // Restrict ourselves to the form CodeGenPrepare typically constructs.
+  auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
+  if (!GEP || GEP->getNumOperands() != 2)
+    return false;
+
+  Value *Base = GEP->getOperand(0);
+  Value *Offsets = GEP->getOperand(1);
+
+  // We only care about scalar_base+vector_offsets.
+  if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
+    return false;
+
+  // Sink extends that would allow us to use 32-bit offset vectors.
+  if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
+    auto *OffsetsInst = cast<Instruction>(Offsets);
+    if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
+        OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
+      Ops.push_back(&GEP->getOperandUse(1));
+  }
+
+  // Sink the GEP.
+  return true;
+}
+
+/// We want to sink the following cases:
+/// (add|sub|gep) A, ((mul|shl) vscale, imm); (add|sub|gep) A, vscale;
+/// (add|sub|gep) A, ((mul|shl) zext(vscale), imm);
+static bool shouldSinkVScale(Value *Op, SmallVectorImpl<Use *> &Ops) {
+  if (match(Op, m_VScale()))
+    return true;
+  if (match(Op, m_Shl(m_VScale(), m_ConstantInt())) ||
+      match(Op, m_Mul(m_VScale(), m_ConstantInt()))) {
+    Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
+    return true;
+  }
+  if (match(Op, m_Shl(m_ZExt(m_VScale()), m_ConstantInt())) ||
+      match(Op, m_Mul(m_ZExt(m_VScale()), m_ConstantInt()))) {
+    Value *ZExtOp = cast<Instruction>(Op)->getOperand(0);
+    Ops.push_back(&cast<Instruction>(ZExtOp)->getOperandUse(0));
+    Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
+    return true;
+  }
+  return false;
+}
+
+/// Check if sinking \p I's operands to I's basic block is profitable, because
+/// the operands can be folded into a target instruction, e.g.
+/// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2).
+bool AArch64TTIImpl::isProfitableToSinkOperands(
+    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::aarch64_neon_smull:
+    case Intrinsic::aarch64_neon_umull:
+      if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1),
+                                   /*AllowSplat=*/true)) {
+        Ops.push_back(&II->getOperandUse(0));
+        Ops.push_back(&II->getOperandUse(1));
+        return true;
+      }
+      [[fallthrough]];
+
+    case Intrinsic::fma:
+    case Intrinsic::fmuladd:
+      if (isa<VectorType>(I->getType()) &&
+          cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
+          !ST->hasFullFP16())
+        return false;
+      [[fallthrough]];
+    case Intrinsic::aarch64_neon_sqdmull:
+    case Intrinsic::aarch64_neon_sqdmulh:
+    case Intrinsic::aarch64_neon_sqrdmulh:
+      // Sink splats for index lane variants
+      if (isSplatShuffle(II->getOperand(0)))
+        Ops.push_back(&II->getOperandUse(0));
+      if (isSplatShuffle(II->getOperand(1)))
+        Ops.push_back(&II->getOperandUse(1));
+      return !Ops.empty();
+    case Intrinsic::aarch64_neon_fmlal:
+    case Intrinsic::aarch64_neon_fmlal2:
+    case Intrinsic::aarch64_neon_fmlsl:
+    case Intrinsic::aarch64_neon_fmlsl2:
+      // Sink splats for index lane variants
+      if (isSplatShuffle(II->getOperand(1)))
+        Ops.push_back(&II->getOperandUse(1));
+      if (isSplatShuffle(II->getOperand(2)))
+        Ops.push_back(&II->getOperandUse(2));
+      return !Ops.empty();
+    case Intrinsic::aarch64_sve_ptest_first:
+    case Intrinsic::aarch64_sve_ptest_last:
+      if (auto *IIOp = dyn_cast<IntrinsicInst>(II->getOperand(0)))
+        if (IIOp->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue)
+          Ops.push_back(&II->getOperandUse(0));
+      return !Ops.empty();
+    case Intrinsic::aarch64_sme_write_horiz:
+    case Intrinsic::aarch64_sme_write_vert:
+    case Intrinsic::aarch64_sme_writeq_horiz:
+    case Intrinsic::aarch64_sme_writeq_vert: {
+      auto *Idx = dyn_cast<Instruction>(II->getOperand(1));
+      if (!Idx || Idx->getOpcode() != Instruction::Add)
+        return false;
+      Ops.push_back(&II->getOperandUse(1));
+      return true;
+    }
+    case Intrinsic::aarch64_sme_read_horiz:
+    case Intrinsic::aarch64_sme_read_vert:
+    case Intrinsic::aarch64_sme_readq_horiz:
+    case Intrinsic::aarch64_sme_readq_vert:
+    case Intrinsic::aarch64_sme_ld1b_vert:
+    case Intrinsic::aarch64_sme_ld1h_vert:
+    case Intrinsic::aarch64_sme_ld1w_vert:
+    case Intrinsic::aarch64_sme_ld1d_vert:
+    case Intrinsic::aarch64_sme_ld1q_vert:
+    case Intrinsic::aarch64_sme_st1b_vert:
+    case Intrinsic::aarch64_sme_st1h_vert:
+    case Intrinsic::aarch64_sme_st1w_vert:
+    case Intrinsic::aarch64_sme_st1d_vert:
+    case Intrinsic::aarch64_sme_st1q_vert:
+    case Intrinsic::aarch64_sme_ld1b_horiz:
+    case Intrinsic::aarch64_sme_ld1h_horiz:
+    case Intrinsic::aarch64_sme_ld1w_horiz:
+    case Intrinsic::aarch64_sme_ld1d_horiz:
+    case Intrinsic::aarch64_sme_ld1q_horiz:
+    case Intrinsic::aarch64_sme_st1b_horiz:
+    case Intrinsic::aarch64_sme_st1h_horiz:
+    case Intrinsic::aarch64_sme_st1w_horiz:
+    case Intrinsic::aarch64_sme_st1d_horiz:
+    case Intrinsic::aarch64_sme_st1q_horiz: {
+      auto *Idx = dyn_cast<Instruction>(II->getOperand(3));
+      if (!Idx || Idx->getOpcode() != Instruction::Add)
+        return false;
+      Ops.push_back(&II->getOperandUse(3));
+      return true;
+    }
+    case Intrinsic::aarch64_neon_pmull:
+      if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
+        return false;
+      Ops.push_back(&II->getOperandUse(0));
+      Ops.push_back(&II->getOperandUse(1));
+      return true;
+    case Intrinsic::aarch64_neon_pmull64:
+      if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
+                                     II->getArgOperand(1)))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(0));
+      Ops.push_back(&II->getArgOperandUse(1));
+      return true;
+    case Intrinsic::masked_gather:
+      if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(0));
+      return true;
+    case Intrinsic::masked_scatter:
+      if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(1));
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  // Sink vscales closer to uses for better isel
+  switch (I->getOpcode()) {
+  case Instruction::GetElementPtr:
+  case Instruction::Add:
+  case Instruction::Sub:
+    for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) {
+      if (shouldSinkVScale(I->getOperand(Op), Ops)) {
+        Ops.push_back(&I->getOperandUse(Op));
+        return true;
+      }
+    }
+    break;
+  default:
+    break;
+  }
+
+  if (!I->getType()->isVectorTy())
+    return false;
+
+  switch (I->getOpcode()) {
+  case Instruction::Sub:
+  case Instruction::Add: {
+    if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
+      return false;
+
+    // If the exts' operands extract either the lower or upper elements, we
+    // can sink them too.
+    auto Ext1 = cast<Instruction>(I->getOperand(0));
+    auto Ext2 = cast<Instruction>(I->getOperand(1));
+    if (areExtractShuffleVectors(Ext1->getOperand(0), Ext2->getOperand(0))) {
+      Ops.push_back(&Ext1->getOperandUse(0));
+      Ops.push_back(&Ext2->getOperandUse(0));
+    }
+
+    Ops.push_back(&I->getOperandUse(0));
+    Ops.push_back(&I->getOperandUse(1));
+
+    return true;
+  }
+  case Instruction::Or: {
+    // Pattern: Or(And(MaskValue, A), And(Not(MaskValue), B)) ->
+    // bitselect(MaskValue, A, B) where Not(MaskValue) = Xor(MaskValue, -1)
+    if (ST->hasNEON()) {
+      Instruction *OtherAnd, *IA, *IB;
+      Value *MaskValue;
+      // MainAnd refers to And instruction that has 'Not' as one of its operands
+      if (match(I, m_c_Or(m_OneUse(m_Instruction(OtherAnd)),
+                          m_OneUse(m_c_And(m_OneUse(m_Not(m_Value(MaskValue))),
+                                           m_Instruction(IA)))))) {
+        if (match(OtherAnd,
+                  m_c_And(m_Specific(MaskValue), m_Instruction(IB)))) {
+          Instruction *MainAnd = I->getOperand(0) == OtherAnd
+                                     ? cast<Instruction>(I->getOperand(1))
+                                     : cast<Instruction>(I->getOperand(0));
+
+          // Both Ands should be in same basic block as Or
+          if (I->getParent() != MainAnd->getParent() ||
+              I->getParent() != OtherAnd->getParent())
+            return false;
+
+          // Non-mask operands of both Ands should also be in same basic block
+          if (I->getParent() != IA->getParent() ||
+              I->getParent() != IB->getParent())
+            return false;
+
+          Ops.push_back(
+              &MainAnd->getOperandUse(MainAnd->getOperand(0) == IA ? 1 : 0));
+          Ops.push_back(&I->getOperandUse(0));
+          Ops.push_back(&I->getOperandUse(1));
+
+          return true;
+        }
+      }
+    }
+
+    return false;
+  }
+  case Instruction::Mul: {
+    int NumZExts = 0, NumSExts = 0;
+    for (auto &Op : I->operands()) {
+      // Make sure we are not already sinking this operand
+      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+        continue;
+
+      if (match(&Op, m_SExt(m_Value()))) {
+        NumSExts++;
+        continue;
+      } else if (match(&Op, m_ZExt(m_Value()))) {
+        NumZExts++;
+        continue;
+      }
+
+      ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
+
+      // If the Shuffle is a splat and the operand is a zext/sext, sinking the
+      // operand and the s/zext can help create indexed s/umull. This is
+      // especially useful to prevent i64 mul being scalarized.
+      if (Shuffle && isSplatShuffle(Shuffle) &&
+          match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
+        Ops.push_back(&Shuffle->getOperandUse(0));
+        Ops.push_back(&Op);
+        if (match(Shuffle->getOperand(0), m_SExt(m_Value())))
+          NumSExts++;
+        else
+          NumZExts++;
+        continue;
+      }
+
+      if (!Shuffle)
+        continue;
+
+      Value *ShuffleOperand = Shuffle->getOperand(0);
+      InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
+      if (!Insert)
+        continue;
+
+      Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
+      if (!OperandInstr)
+        continue;
+
+      ConstantInt *ElementConstant =
+          dyn_cast<ConstantInt>(Insert->getOperand(2));
+      // Check that the insertelement is inserting into element 0
+      if (!ElementConstant || !ElementConstant->isZero())
+        continue;
+
+      unsigned Opcode = OperandInstr->getOpcode();
+      if (Opcode == Instruction::SExt)
+        NumSExts++;
+      else if (Opcode == Instruction::ZExt)
+        NumZExts++;
+      else {
+        // If we find that the top bits are known 0, then we can sink and allow
+        // the backend to generate a umull.
+        unsigned Bitwidth = I->getType()->getScalarSizeInBits();
+        APInt UpperMask = APInt::getHighBitsSet(Bitwidth, Bitwidth / 2);
+        const DataLayout &DL = I->getDataLayout();
+        if (!MaskedValueIsZero(OperandInstr, UpperMask, DL))
+          continue;
+        NumZExts++;
+      }
+
+      Ops.push_back(&Shuffle->getOperandUse(0));
+      Ops.push_back(&Op);
+    }
+
+    // It is profitable to sink if we found two of the same type of extends.
+    return !Ops.empty() && (NumSExts == 2 || NumZExts == 2);
+  }
+  default:
+    return false;
+  }
+  return false;
+}

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 28e45207596ecd..1d09d67f6ec9e3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -416,7 +416,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
                                        StackOffset BaseOffset, bool HasBaseReg,
                                        int64_t Scale, unsigned AddrSpace) const;
-  /// @}
 
   bool enableSelectOptimize() { return ST->enableSelectOptimize(); }
 
@@ -435,6 +434,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2);
+
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
+  /// @}
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index cceb89e23f1290..0f65df0763cc83 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -6043,22 +6043,3 @@ bool AMDGPUTargetLowering::isReassocProfitable(MachineRegisterInfo &MRI,
                                                Register N0, Register N1) const {
   return MRI.hasOneNonDBGUse(N0); // FIXME: handle regbanks
 }
-
-/// Whether it is profitable to sink the operands of an
-/// Instruction I to the basic block of I.
-/// This helps using several modifiers (like abs and neg) more often.
-bool AMDGPUTargetLowering::shouldSinkOperands(
-    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
-  using namespace PatternMatch;
-
-  for (auto &Op : I->operands()) {
-    // Ensure we are not already sinking this operand.
-    if (any_of(Ops, [&](Use *U) { return U->get() == Op.get(); }))
-      continue;
-
-    if (match(&Op, m_FAbs(m_Value())) || match(&Op, m_FNeg(m_Value())))
-      Ops.push_back(&Op);
-  }
-
-  return !Ops.empty();
-}

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 5c2abd334276c1..b2fd31cb2346eb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -387,9 +387,6 @@ class AMDGPUTargetLowering : public TargetLowering {
   MVT getFenceOperandTy(const DataLayout &DL) const override {
     return MVT::i32;
   }
-
-  bool shouldSinkOperands(Instruction *I,
-                          SmallVectorImpl<Use *> &Ops) const override;
 };
 
 namespace AMDGPUISD {

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index 0a2d4e6494305f..3f4f42377d56ee 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -1183,6 +1183,25 @@ InstructionCost GCNTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
   return BaseT::getShuffleCost(Kind, VT, Mask, CostKind, Index, SubTp);
 }
 
+/// Whether it is profitable to sink the operands of an
+/// Instruction I to the basic block of I.
+/// This helps to use several modifiers (like abs and neg) more often.
+bool GCNTTIImpl::isProfitableToSinkOperands(Instruction *I,
+                                            SmallVectorImpl<Use *> &Ops) const {
+  using namespace PatternMatch;
+
+  for (auto &Op : I->operands()) {
+    // Ensure we are not already sinking this operand.
+    if (any_of(Ops, [&](Use *U) { return U->get() == Op.get(); }))
+      continue;
+
+    if (match(&Op, m_FAbs(m_Value())) || match(&Op, m_FNeg(m_Value())))
+      Ops.push_back(&Op);
+  }
+
+  return !Ops.empty();
+}
+
 bool GCNTTIImpl::areInlineCompatible(const Function *Caller,
                                      const Function *Callee) const {
   const TargetMachine &TM = getTLI()->getTargetMachine();

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index 76785ee456a417..30da002376251c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -237,6 +237,9 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
                                  ArrayRef<const Value *> Args = {},
                                  const Instruction *CxtI = nullptr);
 
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
+
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
 

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 1733424a8b669f..bf757edfa85890 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -19283,149 +19283,6 @@ bool ARMTargetLowering::isFNegFree(EVT VT) const {
   return false;
 }
 
-/// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
-/// of the vector elements.
-static bool areExtractExts(Value *Ext1, Value *Ext2) {
-  auto areExtDoubled = [](Instruction *Ext) {
-    return Ext->getType()->getScalarSizeInBits() ==
-           2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
-  };
-
-  if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
-      !match(Ext2, m_ZExtOrSExt(m_Value())) ||
-      !areExtDoubled(cast<Instruction>(Ext1)) ||
-      !areExtDoubled(cast<Instruction>(Ext2)))
-    return false;
-
-  return true;
-}
-
-/// Check if sinking \p I's operands to I's basic block is profitable, because
-/// the operands can be folded into a target instruction, e.g.
-/// sext/zext can be folded into vsubl.
-bool ARMTargetLowering::shouldSinkOperands(Instruction *I,
-                                           SmallVectorImpl<Use *> &Ops) const {
-  if (!I->getType()->isVectorTy())
-    return false;
-
-  if (Subtarget->hasNEON()) {
-    switch (I->getOpcode()) {
-    case Instruction::Sub:
-    case Instruction::Add: {
-      if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
-        return false;
-      Ops.push_back(&I->getOperandUse(0));
-      Ops.push_back(&I->getOperandUse(1));
-      return true;
-    }
-    default:
-      return false;
-    }
-  }
-
-  if (!Subtarget->hasMVEIntegerOps())
-    return false;
-
-  auto IsFMSMul = [&](Instruction *I) {
-    if (!I->hasOneUse())
-      return false;
-    auto *Sub = cast<Instruction>(*I->users().begin());
-    return Sub->getOpcode() == Instruction::FSub && Sub->getOperand(1) == I;
-  };
-  auto IsFMS = [&](Instruction *I) {
-    if (match(I->getOperand(0), m_FNeg(m_Value())) ||
-        match(I->getOperand(1), m_FNeg(m_Value())))
-      return true;
-    return false;
-  };
-
-  auto IsSinker = [&](Instruction *I, int Operand) {
-    switch (I->getOpcode()) {
-    case Instruction::Add:
-    case Instruction::Mul:
-    case Instruction::FAdd:
-    case Instruction::ICmp:
-    case Instruction::FCmp:
-      return true;
-    case Instruction::FMul:
-      return !IsFMSMul(I);
-    case Instruction::Sub:
-    case Instruction::FSub:
-    case Instruction::Shl:
-    case Instruction::LShr:
-    case Instruction::AShr:
-      return Operand == 1;
-    case Instruction::Call:
-      if (auto *II = dyn_cast<IntrinsicInst>(I)) {
-        switch (II->getIntrinsicID()) {
-        case Intrinsic::fma:
-          return !IsFMS(I);
-        case Intrinsic::sadd_sat:
-        case Intrinsic::uadd_sat:
-        case Intrinsic::arm_mve_add_predicated:
-        case Intrinsic::arm_mve_mul_predicated:
-        case Intrinsic::arm_mve_qadd_predicated:
-        case Intrinsic::arm_mve_vhadd:
-        case Intrinsic::arm_mve_hadd_predicated:
-        case Intrinsic::arm_mve_vqdmull:
-        case Intrinsic::arm_mve_vqdmull_predicated:
-        case Intrinsic::arm_mve_vqdmulh:
-        case Intrinsic::arm_mve_qdmulh_predicated:
-        case Intrinsic::arm_mve_vqrdmulh:
-        case Intrinsic::arm_mve_qrdmulh_predicated:
-        case Intrinsic::arm_mve_fma_predicated:
-          return true;
-        case Intrinsic::ssub_sat:
-        case Intrinsic::usub_sat:
-        case Intrinsic::arm_mve_sub_predicated:
-        case Intrinsic::arm_mve_qsub_predicated:
-        case Intrinsic::arm_mve_hsub_predicated:
-        case Intrinsic::arm_mve_vhsub:
-          return Operand == 1;
-        default:
-          return false;
-        }
-      }
-      return false;
-    default:
-      return false;
-    }
-  };
-
-  for (auto OpIdx : enumerate(I->operands())) {
-    Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
-    // Make sure we are not already sinking this operand
-    if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
-      continue;
-
-    Instruction *Shuffle = Op;
-    if (Shuffle->getOpcode() == Instruction::BitCast)
-      Shuffle = dyn_cast<Instruction>(Shuffle->getOperand(0));
-    // We are looking for a splat that can be sunk.
-    if (!Shuffle ||
-        !match(Shuffle, m_Shuffle(
-                            m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
-                            m_Undef(), m_ZeroMask())))
-      continue;
-    if (!IsSinker(I, OpIdx.index()))
-      continue;
-
-    // All uses of the shuffle should be sunk to avoid duplicating it across gpr
-    // and vector registers
-    for (Use &U : Op->uses()) {
-      Instruction *Insn = cast<Instruction>(U.getUser());
-      if (!IsSinker(Insn, U.getOperandNo()))
-        return false;
-    }
-
-    Ops.push_back(&Shuffle->getOperandUse(0));
-    if (Shuffle != Op)
-      Ops.push_back(&Op->getOperandUse(0));
-    Ops.push_back(&OpIdx.value());
-  }
-  return true;
-}
-
 Type *ARMTargetLowering::shouldConvertSplatType(ShuffleVectorInst *SVI) const {
   if (!Subtarget->hasMVEIntegerOps())
     return nullptr;

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index a255e9b6fc365f..316f7d3b9bce5d 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -460,8 +460,6 @@ class VectorType;
     bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
     bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
     bool isZExtFree(SDValue Val, EVT VT2) const override;
-    bool shouldSinkOperands(Instruction *I,
-                            SmallVectorImpl<Use *> &Ops) const override;
     Type* shouldConvertSplatType(ShuffleVectorInst* SVI) const override;
 
     bool isFNegFree(EVT VT) const override;

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 865e2f3066ef01..835ae98efb852d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -2659,3 +2659,149 @@ bool ARMTTIImpl::hasArmWideBranch(bool Thumb) const {
     return ST->hasARMOps();
   }
 }
+
+/// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
+/// of the vector elements.
+static bool areExtractExts(Value *Ext1, Value *Ext2) {
+  using namespace PatternMatch;
+
+  auto areExtDoubled = [](Instruction *Ext) {
+    return Ext->getType()->getScalarSizeInBits() ==
+           2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
+  };
+
+  if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
+      !match(Ext2, m_ZExtOrSExt(m_Value())) ||
+      !areExtDoubled(cast<Instruction>(Ext1)) ||
+      !areExtDoubled(cast<Instruction>(Ext2)))
+    return false;
+
+  return true;
+}
+
+/// Check if sinking \p I's operands to I's basic block is profitable, because
+/// the operands can be folded into a target instruction, e.g.
+/// sext/zext can be folded into vsubl.
+bool ARMTTIImpl::isProfitableToSinkOperands(Instruction *I,
+                                            SmallVectorImpl<Use *> &Ops) const {
+  using namespace PatternMatch;
+
+  if (!I->getType()->isVectorTy())
+    return false;
+
+  if (ST->hasNEON()) {
+    switch (I->getOpcode()) {
+    case Instruction::Sub:
+    case Instruction::Add: {
+      if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
+        return false;
+      Ops.push_back(&I->getOperandUse(0));
+      Ops.push_back(&I->getOperandUse(1));
+      return true;
+    }
+    default:
+      return false;
+    }
+  }
+
+  if (!ST->hasMVEIntegerOps())
+    return false;
+
+  auto IsFMSMul = [&](Instruction *I) {
+    if (!I->hasOneUse())
+      return false;
+    auto *Sub = cast<Instruction>(*I->users().begin());
+    return Sub->getOpcode() == Instruction::FSub && Sub->getOperand(1) == I;
+  };
+  auto IsFMS = [&](Instruction *I) {
+    if (match(I->getOperand(0), m_FNeg(m_Value())) ||
+        match(I->getOperand(1), m_FNeg(m_Value())))
+      return true;
+    return false;
+  };
+
+  auto IsSinker = [&](Instruction *I, int Operand) {
+    switch (I->getOpcode()) {
+    case Instruction::Add:
+    case Instruction::Mul:
+    case Instruction::FAdd:
+    case Instruction::ICmp:
+    case Instruction::FCmp:
+      return true;
+    case Instruction::FMul:
+      return !IsFMSMul(I);
+    case Instruction::Sub:
+    case Instruction::FSub:
+    case Instruction::Shl:
+    case Instruction::LShr:
+    case Instruction::AShr:
+      return Operand == 1;
+    case Instruction::Call:
+      if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+        switch (II->getIntrinsicID()) {
+        case Intrinsic::fma:
+          return !IsFMS(I);
+        case Intrinsic::sadd_sat:
+        case Intrinsic::uadd_sat:
+        case Intrinsic::arm_mve_add_predicated:
+        case Intrinsic::arm_mve_mul_predicated:
+        case Intrinsic::arm_mve_qadd_predicated:
+        case Intrinsic::arm_mve_vhadd:
+        case Intrinsic::arm_mve_hadd_predicated:
+        case Intrinsic::arm_mve_vqdmull:
+        case Intrinsic::arm_mve_vqdmull_predicated:
+        case Intrinsic::arm_mve_vqdmulh:
+        case Intrinsic::arm_mve_qdmulh_predicated:
+        case Intrinsic::arm_mve_vqrdmulh:
+        case Intrinsic::arm_mve_qrdmulh_predicated:
+        case Intrinsic::arm_mve_fma_predicated:
+          return true;
+        case Intrinsic::ssub_sat:
+        case Intrinsic::usub_sat:
+        case Intrinsic::arm_mve_sub_predicated:
+        case Intrinsic::arm_mve_qsub_predicated:
+        case Intrinsic::arm_mve_hsub_predicated:
+        case Intrinsic::arm_mve_vhsub:
+          return Operand == 1;
+        default:
+          return false;
+        }
+      }
+      return false;
+    default:
+      return false;
+    }
+  };
+
+  for (auto OpIdx : enumerate(I->operands())) {
+    Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
+    // Make sure we are not already sinking this operand
+    if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+      continue;
+
+    Instruction *Shuffle = Op;
+    if (Shuffle->getOpcode() == Instruction::BitCast)
+      Shuffle = dyn_cast<Instruction>(Shuffle->getOperand(0));
+    // We are looking for a splat that can be sunk.
+    if (!Shuffle || !match(Shuffle, m_Shuffle(m_InsertElt(m_Undef(), m_Value(),
+                                                          m_ZeroInt()),
+                                              m_Undef(), m_ZeroMask())))
+      continue;
+    if (!IsSinker(I, OpIdx.index()))
+      continue;
+
+    // All uses of the shuffle should be sunk to avoid duplicating it across gpr
+    // and vector registers
+    for (Use &U : Op->uses()) {
+      Instruction *Insn = cast<Instruction>(U.getUser());
+      if (!IsSinker(Insn, U.getOperandNo()))
+        return false;
+    }
+
+    Ops.push_back(&Shuffle->getOperandUse(0));
+    if (Shuffle != Op)
+      Ops.push_back(&Op->getOperandUse(0));
+    Ops.push_back(&OpIdx.value());
+  }
+  return true;
+}

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 7be53c4bcaa295..b0a75134ee02b7 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -335,6 +335,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool hasArmWideBranch(bool Thumb) const;
 
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
   /// @}
 };
 

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 463887b8b55e61..01fa418e4dbdf4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2068,145 +2068,6 @@ bool RISCVTargetLowering::
   return !XC;
 }
 
-bool RISCVTargetLowering::canSplatOperand(unsigned Opcode, int Operand) const {
-  switch (Opcode) {
-  case Instruction::Add:
-  case Instruction::Sub:
-  case Instruction::Mul:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor:
-  case Instruction::FAdd:
-  case Instruction::FSub:
-  case Instruction::FMul:
-  case Instruction::FDiv:
-  case Instruction::ICmp:
-  case Instruction::FCmp:
-    return true;
-  case Instruction::Shl:
-  case Instruction::LShr:
-  case Instruction::AShr:
-  case Instruction::UDiv:
-  case Instruction::SDiv:
-  case Instruction::URem:
-  case Instruction::SRem:
-  case Instruction::Select:
-    return Operand == 1;
-  default:
-    return false;
-  }
-}
-
-
-bool RISCVTargetLowering::canSplatOperand(Instruction *I, int Operand) const {
-  if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
-    return false;
-
-  if (canSplatOperand(I->getOpcode(), Operand))
-    return true;
-
-  auto *II = dyn_cast<IntrinsicInst>(I);
-  if (!II)
-    return false;
-
-  switch (II->getIntrinsicID()) {
-  case Intrinsic::fma:
-  case Intrinsic::vp_fma:
-    return Operand == 0 || Operand == 1;
-  case Intrinsic::vp_shl:
-  case Intrinsic::vp_lshr:
-  case Intrinsic::vp_ashr:
-  case Intrinsic::vp_udiv:
-  case Intrinsic::vp_sdiv:
-  case Intrinsic::vp_urem:
-  case Intrinsic::vp_srem:
-  case Intrinsic::ssub_sat:
-  case Intrinsic::vp_ssub_sat:
-  case Intrinsic::usub_sat:
-  case Intrinsic::vp_usub_sat:
-    return Operand == 1;
-    // These intrinsics are commutative.
-  case Intrinsic::vp_add:
-  case Intrinsic::vp_mul:
-  case Intrinsic::vp_and:
-  case Intrinsic::vp_or:
-  case Intrinsic::vp_xor:
-  case Intrinsic::vp_fadd:
-  case Intrinsic::vp_fmul:
-  case Intrinsic::vp_icmp:
-  case Intrinsic::vp_fcmp:
-  case Intrinsic::smin:
-  case Intrinsic::vp_smin:
-  case Intrinsic::umin:
-  case Intrinsic::vp_umin:
-  case Intrinsic::smax:
-  case Intrinsic::vp_smax:
-  case Intrinsic::umax:
-  case Intrinsic::vp_umax:
-  case Intrinsic::sadd_sat:
-  case Intrinsic::vp_sadd_sat:
-  case Intrinsic::uadd_sat:
-  case Intrinsic::vp_uadd_sat:
-    // These intrinsics have 'vr' versions.
-  case Intrinsic::vp_sub:
-  case Intrinsic::vp_fsub:
-  case Intrinsic::vp_fdiv:
-    return Operand == 0 || Operand == 1;
-  default:
-    return false;
-  }
-}
-
-/// Check if sinking \p I's operands to I's basic block is profitable, because
-/// the operands can be folded into a target instruction, e.g.
-/// splats of scalars can fold into vector instructions.
-bool RISCVTargetLowering::shouldSinkOperands(
-    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
-  using namespace llvm::PatternMatch;
-
-  if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
-    return false;
-
-  // Don't sink splat operands if the target prefers it. Some targets requires
-  // S2V transfer buffers and we can run out of them copying the same value
-  // repeatedly.
-  // FIXME: It could still be worth doing if it would improve vector register
-  // pressure and prevent a vector spill.
-  if (!Subtarget.sinkSplatOperands())
-    return false;
-
-  for (auto OpIdx : enumerate(I->operands())) {
-    if (!canSplatOperand(I, OpIdx.index()))
-      continue;
-
-    Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
-    // Make sure we are not already sinking this operand
-    if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
-      continue;
-
-    // We are looking for a splat that can be sunk.
-    if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
-                             m_Undef(), m_ZeroMask())))
-      continue;
-
-    // Don't sink i1 splats.
-    if (cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(1))
-      continue;
-
-    // All uses of the shuffle should be sunk to avoid duplicating it across gpr
-    // and vector registers
-    for (Use &U : Op->uses()) {
-      Instruction *Insn = cast<Instruction>(U.getUser());
-      if (!canSplatOperand(Insn, U.getOperandNo()))
-        return false;
-    }
-
-    Ops.push_back(&Op->getOperandUse(0));
-    Ops.push_back(&OpIdx.value());
-  }
-  return true;
-}
-
 bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
   unsigned Opc = VecOp.getOpcode();
 

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 05581552ab6041..3864d58a129e98 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -520,14 +520,6 @@ class RISCVTargetLowering : public TargetLowering {
       SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
       unsigned OldShiftOpcode, unsigned NewShiftOpcode,
       SelectionDAG &DAG) const override;
-  /// Return true if the (vector) instruction I will be lowered to an instruction
-  /// with a scalar splat operand for the given Operand number.
-  bool canSplatOperand(Instruction *I, int Operand) const;
-  /// Return true if a vector instruction will lower to a target instruction
-  /// able to splat the given operand.
-  bool canSplatOperand(unsigned Opcode, int Operand) const;
-  bool shouldSinkOperands(Instruction *I,
-                          SmallVectorImpl<Use *> &Ops) const override;
   bool shouldScalarizeBinop(SDValue VecOp) const override;
   bool isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const override;
   int getLegalZfaFPImm(const APFloat &Imm, EVT VT) const;

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index a61461681f79ed..8d18fd63e4a2e1 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1979,8 +1979,8 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
   }
 
   auto getConstantMatCost =
-    [&](unsigned Operand, TTI::OperandValueInfo OpInfo) -> InstructionCost {
-    if (OpInfo.isUniform() && TLI->canSplatOperand(Opcode, Operand))
+      [&](unsigned Operand, TTI::OperandValueInfo OpInfo) -> InstructionCost {
+    if (OpInfo.isUniform() && canSplatOperand(Opcode, Operand))
       // Two sub-cases:
       // * Has a 5 bit immediate operand which can be splatted.
       // * Has a larger immediate which must be materialized in scalar register
@@ -2294,3 +2294,141 @@ bool RISCVTTIImpl::shouldConsiderAddressTypePromotion(
   }
   return Considerable;
 }
+
+bool RISCVTTIImpl::canSplatOperand(unsigned Opcode, int Operand) const {
+  switch (Opcode) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::FAdd:
+  case Instruction::FSub:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+    return true;
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+  case Instruction::Select:
+    return Operand == 1;
+  default:
+    return false;
+  }
+}
+
+bool RISCVTTIImpl::canSplatOperand(Instruction *I, int Operand) const {
+  if (!I->getType()->isVectorTy() || !ST->hasVInstructions())
+    return false;
+
+  if (canSplatOperand(I->getOpcode(), Operand))
+    return true;
+
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  if (!II)
+    return false;
+
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::fma:
+  case Intrinsic::vp_fma:
+    return Operand == 0 || Operand == 1;
+  case Intrinsic::vp_shl:
+  case Intrinsic::vp_lshr:
+  case Intrinsic::vp_ashr:
+  case Intrinsic::vp_udiv:
+  case Intrinsic::vp_sdiv:
+  case Intrinsic::vp_urem:
+  case Intrinsic::vp_srem:
+  case Intrinsic::ssub_sat:
+  case Intrinsic::vp_ssub_sat:
+  case Intrinsic::usub_sat:
+  case Intrinsic::vp_usub_sat:
+    return Operand == 1;
+    // These intrinsics are commutative.
+  case Intrinsic::vp_add:
+  case Intrinsic::vp_mul:
+  case Intrinsic::vp_and:
+  case Intrinsic::vp_or:
+  case Intrinsic::vp_xor:
+  case Intrinsic::vp_fadd:
+  case Intrinsic::vp_fmul:
+  case Intrinsic::vp_icmp:
+  case Intrinsic::vp_fcmp:
+  case Intrinsic::smin:
+  case Intrinsic::vp_smin:
+  case Intrinsic::umin:
+  case Intrinsic::vp_umin:
+  case Intrinsic::smax:
+  case Intrinsic::vp_smax:
+  case Intrinsic::umax:
+  case Intrinsic::vp_umax:
+  case Intrinsic::sadd_sat:
+  case Intrinsic::vp_sadd_sat:
+  case Intrinsic::uadd_sat:
+  case Intrinsic::vp_uadd_sat:
+    // These intrinsics have 'vr' versions.
+  case Intrinsic::vp_sub:
+  case Intrinsic::vp_fsub:
+  case Intrinsic::vp_fdiv:
+    return Operand == 0 || Operand == 1;
+  default:
+    return false;
+  }
+}
+
+/// Check if sinking \p I's operands to I's basic block is profitable, because
+/// the operands can be folded into a target instruction, e.g.
+/// splats of scalars can fold into vector instructions.
+bool RISCVTTIImpl::isProfitableToSinkOperands(
+    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
+  using namespace llvm::PatternMatch;
+
+  if (!I->getType()->isVectorTy() || !ST->hasVInstructions())
+    return false;
+
+  // Don't sink splat operands if the target prefers not to. Some targets
+  // require S2V transfer buffers and we can run out of them copying the same
+  // value repeatedly.
+  // FIXME: It could still be worth doing if it would improve vector register
+  // pressure and prevent a vector spill.
+  if (!ST->sinkSplatOperands())
+    return false;
+
+  for (auto OpIdx : enumerate(I->operands())) {
+    if (!canSplatOperand(I, OpIdx.index()))
+      continue;
+
+    Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
+    // Make sure we are not already sinking this operand
+    if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+      continue;
+
+    // We are looking for a splat that can be sunk.
+    if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
+                             m_Undef(), m_ZeroMask())))
+      continue;
+
+    // Don't sink i1 splats.
+    if (cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(1))
+      continue;
+
+    // All uses of the shuffle should be sunk to avoid duplicating it across gpr
+    // and vector registers
+    for (Use &U : Op->uses()) {
+      Instruction *Insn = cast<Instruction>(U.getUser());
+      if (!canSplatOperand(Insn, U.getOperandNo()))
+        return false;
+    }
+
+    Ops.push_back(&Op->getOperandUse(0));
+    Ops.push_back(&OpIdx.value());
+  }
+  return true;
+}
\ No newline at end of file

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 65bbd905508557..3f50bd86b9b3b6 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -412,6 +412,15 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   shouldConsiderAddressTypePromotion(const Instruction &I,
                                      bool &AllowPromotionWithoutCommonHeader);
   std::optional<unsigned> getMinPageSize() const { return 4096; }
+  /// Return true if the (vector) instruction I will be lowered to an
+  /// instruction with a scalar splat operand for the given Operand number.
+  bool canSplatOperand(Instruction *I, int Operand) const;
+  /// Return true if a vector instruction will lower to a target instruction
+  /// able to splat the given operand.
+  bool canSplatOperand(unsigned Opcode, int Operand) const;
+
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index fa78bf38f426cd..5f76d666823e28 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -843,30 +843,6 @@ bool WebAssemblyTargetLowering::isOffsetFoldingLegal(
   return isa<Function>(GV) ? false : TargetLowering::isOffsetFoldingLegal(GA);
 }
 
-bool WebAssemblyTargetLowering::shouldSinkOperands(
-    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
-  using namespace llvm::PatternMatch;
-
-  if (!I->getType()->isVectorTy() || !I->isShift())
-    return false;
-
-  Value *V = I->getOperand(1);
-  // We dont need to sink constant splat.
-  if (dyn_cast<Constant>(V))
-    return false;
-
-  if (match(V, m_Shuffle(m_InsertElt(m_Value(), m_Value(), m_ZeroInt()),
-                         m_Value(), m_ZeroMask()))) {
-    // Sink insert
-    Ops.push_back(&cast<Instruction>(V)->getOperandUse(0));
-    // Sink shuffle
-    Ops.push_back(&I->getOperandUse(1));
-    return true;
-  }
-
-  return false;
-}
-
 EVT WebAssemblyTargetLowering::getSetCCResultType(const DataLayout &DL,
                                                   LLVMContext &C,
                                                   EVT VT) const {

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 7d9cfb7739e435..139b064aa04230 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -76,8 +76,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   bool isIntDivCheap(EVT VT, AttributeList Attr) const override;
   bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const override;
-  bool shouldSinkOperands(Instruction *I,
-                          SmallVectorImpl<Use *> &Ops) const override;
   EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Context,
                          EVT VT) const override;
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index b109594811d97f..9fe5e5f27f8dad 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -154,3 +154,27 @@ void WebAssemblyTTIImpl::getUnrollingPreferences(
 bool WebAssemblyTTIImpl::supportsTailCalls() const {
   return getST()->hasTailCall();
 }
+
+bool WebAssemblyTTIImpl::isProfitableToSinkOperands(
+    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
+  using namespace llvm::PatternMatch;
+
+  if (!I->getType()->isVectorTy() || !I->isShift())
+    return false;
+
+  Value *V = I->getOperand(1);
+  // We don't need to sink a constant splat.
+  if (dyn_cast<Constant>(V))
+    return false;
+
+  if (match(V, m_Shuffle(m_InsertElt(m_Value(), m_Value(), m_ZeroInt()),
+                         m_Value(), m_ZeroMask()))) {
+    // Sink insert
+    Ops.push_back(&cast<Instruction>(V)->getOperandUse(0));
+    // Sink shuffle
+    Ops.push_back(&I->getOperandUse(1));
+    return true;
+  }
+
+  return false;
+}

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index ac3a333991684d..2ce6cbf3ba0266 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -71,12 +71,16 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
 
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
-  /// @}
 
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
 
   bool supportsTailCalls() const;
+
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
+
+  /// @}
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ddbe82b1de5cfc..70f06b8d3a5f27 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -34671,29 +34671,6 @@ bool X86TargetLowering::isLegalAddressingMode(const DataLayout &DL,
   return true;
 }
 
-bool X86TargetLowering::isVectorShiftByScalarCheap(Type *Ty) const {
-  unsigned Bits = Ty->getScalarSizeInBits();
-
-  // XOP has v16i8/v8i16/v4i32/v2i64 variable vector shifts.
-  // Splitting for v32i8/v16i16 on XOP+AVX2 targets is still preferred.
-  if (Subtarget.hasXOP() &&
-      (Bits == 8 || Bits == 16 || Bits == 32 || Bits == 64))
-    return false;
-
-  // AVX2 has vpsllv[dq] instructions (and other shifts) that make variable
-  // shifts just as cheap as scalar ones.
-  if (Subtarget.hasAVX2() && (Bits == 32 || Bits == 64))
-    return false;
-
-  // AVX512BW has shifts such as vpsllvw.
-  if (Subtarget.hasBWI() && Bits == 16)
-    return false;
-
-  // Otherwise, it's significantly cheaper to shift by a scalar amount than by a
-  // fully general vector.
-  return true;
-}
-
 bool X86TargetLowering::isBinOp(unsigned Opcode) const {
   switch (Opcode) {
   // These are non-commutative binops.
@@ -34808,63 +34785,6 @@ bool X86TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
   return false;
 }
 
-bool X86TargetLowering::shouldSinkOperands(Instruction *I,
-                                           SmallVectorImpl<Use *> &Ops) const {
-  using namespace llvm::PatternMatch;
-
-  FixedVectorType *VTy = dyn_cast<FixedVectorType>(I->getType());
-  if (!VTy)
-    return false;
-
-  if (I->getOpcode() == Instruction::Mul &&
-      VTy->getElementType()->isIntegerTy(64)) {
-    for (auto &Op : I->operands()) {
-      // Make sure we are not already sinking this operand
-      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
-        continue;
-
-      // Look for PMULDQ pattern where the input is a sext_inreg from vXi32 or
-      // the PMULUDQ pattern where the input is a zext_inreg from vXi32.
-      if (Subtarget.hasSSE41() &&
-          match(Op.get(), m_AShr(m_Shl(m_Value(), m_SpecificInt(32)),
-                                 m_SpecificInt(32)))) {
-        Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
-        Ops.push_back(&Op);
-      } else if (Subtarget.hasSSE2() &&
-                 match(Op.get(),
-                       m_And(m_Value(), m_SpecificInt(UINT64_C(0xffffffff))))) {
-        Ops.push_back(&Op);
-      }
-    }
-
-    return !Ops.empty();
-  }
-
-  // A uniform shift amount in a vector shift or funnel shift may be much
-  // cheaper than a generic variable vector shift, so make that pattern visible
-  // to SDAG by sinking the shuffle instruction next to the shift.
-  int ShiftAmountOpNum = -1;
-  if (I->isShift())
-    ShiftAmountOpNum = 1;
-  else if (auto *II = dyn_cast<IntrinsicInst>(I)) {
-    if (II->getIntrinsicID() == Intrinsic::fshl ||
-        II->getIntrinsicID() == Intrinsic::fshr)
-      ShiftAmountOpNum = 2;
-  }
-
-  if (ShiftAmountOpNum == -1)
-    return false;
-
-  auto *Shuf = dyn_cast<ShuffleVectorInst>(I->getOperand(ShiftAmountOpNum));
-  if (Shuf && getSplatIndex(Shuf->getShuffleMask()) >= 0 &&
-      isVectorShiftByScalarCheap(I->getType())) {
-    Ops.push_back(&I->getOperandUse(ShiftAmountOpNum));
-    return true;
-  }
-
-  return false;
-}
-
 bool X86TargetLowering::shouldConvertPhiType(Type *From, Type *To) const {
   if (!Subtarget.is64Bit())
     return false;

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 0ab42f032c3ea6..a2515ff35e6925 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1404,10 +1404,6 @@ namespace llvm {
 
     bool isLegalStoreImmediate(int64_t Imm) const override;
 
-    /// This is used to enable splatted operand transforms for vector shifts
-    /// and vector funnel shifts.
-    bool isVectorShiftByScalarCheap(Type *Ty) const override;
-
     /// Add x86-specific opcodes to the default list.
     bool isBinOp(unsigned Opcode) const override;
 
@@ -1434,8 +1430,6 @@ namespace llvm {
     bool isZExtFree(EVT VT1, EVT VT2) const override;
     bool isZExtFree(SDValue Val, EVT VT2) const override;
 
-    bool shouldSinkOperands(Instruction *I,
-                            SmallVectorImpl<Use *> &Ops) const override;
     bool shouldConvertPhiType(Type *From, Type *To) const override;
 
     /// Return true if folding a vector load into ExtVal (a sign, zero, or any

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index aa84e3887c3890..413ef0136d5c06 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6900,3 +6900,82 @@ InstructionCost X86TTIImpl::getBranchMispredictPenalty() const {
   // TODO: Hook MispredictPenalty of SchedMachineModel into this.
   return 14;
 }
+
+bool X86TTIImpl::isVectorShiftByScalarCheap(Type *Ty) const {
+  unsigned Bits = Ty->getScalarSizeInBits();
+
+  // XOP has v16i8/v8i16/v4i32/v2i64 variable vector shifts.
+  // Splitting for v32i8/v16i16 on XOP+AVX2 targets is still preferred.
+  if (ST->hasXOP() && (Bits == 8 || Bits == 16 || Bits == 32 || Bits == 64))
+    return false;
+
+  // AVX2 has vpsllv[dq] instructions (and other shifts) that make variable
+  // shifts just as cheap as scalar ones.
+  if (ST->hasAVX2() && (Bits == 32 || Bits == 64))
+    return false;
+
+  // AVX512BW has shifts such as vpsllvw.
+  if (ST->hasBWI() && Bits == 16)
+    return false;
+
+  // Otherwise, it's significantly cheaper to shift by a scalar amount than by a
+  // fully general vector.
+  return true;
+}
+
+bool X86TTIImpl::isProfitableToSinkOperands(Instruction *I,
+                                            SmallVectorImpl<Use *> &Ops) const {
+  using namespace llvm::PatternMatch;
+
+  FixedVectorType *VTy = dyn_cast<FixedVectorType>(I->getType());
+  if (!VTy)
+    return false;
+
+  if (I->getOpcode() == Instruction::Mul &&
+      VTy->getElementType()->isIntegerTy(64)) {
+    for (auto &Op : I->operands()) {
+      // Make sure we are not already sinking this operand
+      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+        continue;
+
+      // Look for PMULDQ pattern where the input is a sext_inreg from vXi32 or
+      // the PMULUDQ pattern where the input is a zext_inreg from vXi32.
+      if (ST->hasSSE41() &&
+          match(Op.get(), m_AShr(m_Shl(m_Value(), m_SpecificInt(32)),
+                                 m_SpecificInt(32)))) {
+        Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
+        Ops.push_back(&Op);
+      } else if (ST->hasSSE2() &&
+                 match(Op.get(),
+                       m_And(m_Value(), m_SpecificInt(UINT64_C(0xffffffff))))) {
+        Ops.push_back(&Op);
+      }
+    }
+
+    return !Ops.empty();
+  }
+
+  // A uniform shift amount in a vector shift or funnel shift may be much
+  // cheaper than a generic variable vector shift, so make that pattern visible
+  // to SDAG by sinking the shuffle instruction next to the shift.
+  int ShiftAmountOpNum = -1;
+  if (I->isShift())
+    ShiftAmountOpNum = 1;
+  else if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    if (II->getIntrinsicID() == Intrinsic::fshl ||
+        II->getIntrinsicID() == Intrinsic::fshr)
+      ShiftAmountOpNum = 2;
+  }
+
+  if (ShiftAmountOpNum == -1)
+    return false;
+
+  auto *Shuf = dyn_cast<ShuffleVectorInst>(I->getOperand(ShiftAmountOpNum));
+  if (Shuf && getSplatIndex(Shuf->getShuffleMask()) >= 0 &&
+      isVectorShiftByScalarCheap(I->getType())) {
+    Ops.push_back(&I->getOperandUse(ShiftAmountOpNum));
+    return true;
+  }
+
+  return false;
+}

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index c16461b157e07f..0100f328ab4bd3 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -297,6 +297,11 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
 
   InstructionCost getBranchMispredictPenalty() const;
 
+  bool isProfitableToSinkOperands(Instruction *I,
+                                  SmallVectorImpl<Use *> &Ops) const;
+
+  bool isVectorShiftByScalarCheap(Type *Ty) const;
+
 private:
   bool supportsGather() const;
   InstructionCost getGSVectorCost(unsigned Opcode, TTI::TargetCostKind CostKind,


        


More information about the llvm-commits mailing list