[llvm] [X86] Support hoisting load/store with conditional faulting (PR #95515)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 04:08:15 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-x86

Author: Shengchen Kan (KanRobert)

<details>
<summary>Changes</summary>

1. Add a TTI interface for conditional load/store.
2. In SimplifyCFG, hoist loads/stores out of conditional successor blocks
   as masked loads/stores when the target supports conditional faulting
   (a before/after IR sketch follows this list).
3. Mark 1 x i16/i32/i64 masked load/store legal so that they are not
   scalarized by the scalarize-masked-mem-intrin pass.
4. Visit 1 x i16/i32/i64 masked load/store in the SelectionDAG builder and
   emit a target-specific conditional load/store node, avoiding an error in
   DAGTypeLegalizer::ScalarizeVectorResult.
5. Lower the conditional load/store nodes to CFCMOV.
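
To illustrate item 2, here is a minimal before/after sketch in LLVM IR (function and value names are illustrative; the added hoist-load-store-with-cf.ll test carries the authoritative output):

```llvm
; Before: the load/store execute only when %cond is true.
define void @f(i1 %cond, ptr %b, ptr %p) {
entry:
  br i1 %cond, label %if.true, label %if.end

if.true:
  %v = load i32, ptr %b, align 4
  store i32 %v, ptr %p, align 4
  br label %if.end

if.end:
  ret void
}
```

```llvm
; After: both memory ops are hoisted into %entry as single-element masked
; ops whose mask is the branch condition, so a fault is suppressed exactly
; when the original program would not have executed the access.
define void @f(i1 %cond, ptr %b, ptr %p) {
entry:
  %mask = bitcast i1 %cond to <1 x i1>
  %v.vec = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr %b, i32 4, <1 x i1> %mask, <1 x i32> poison)
  %v = bitcast <1 x i32> %v.vec to i32  ; scalar users keep using %v
  %s.vec = bitcast i32 %v to <1 x i32>
  call void @llvm.masked.store.v1i32.p0(<1 x i32> %s.vec, ptr %p, i32 4, <1 x i1> %mask)
  ret void
}

declare <1 x i32> @llvm.masked.load.v1i32.p0(ptr, i32 immarg, <1 x i1>, <1 x i32>)
declare void @llvm.masked.store.v1i32.p0(<1 x i32>, ptr, i32 immarg, <1 x i1>)
```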

---

Patch is 41.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95515.diff


14 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+8) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1) 
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+12) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+5) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+26-6) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+50) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.h (+16) 
- (modified) llvm/lib/Target/X86/X86InstrCMovSetCC.td (+29) 
- (modified) llvm/lib/Target/X86/X86InstrFragments.td (+12) 
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.cpp (+28-4) 
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+1) 
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+187-3) 
- (added) llvm/test/CodeGen/X86/apx/cf.ll (+90) 
- (added) llvm/test/Transforms/SimplifyCFG/X86/hoist-load-store-with-cf.ll (+357) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f55f21c94a85a..37afda39a1c9c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
   /// \return the number of registers in the target-provided register class.
   unsigned getNumberOfRegisters(unsigned ClassID) const;
 
+  /// \return true if the target supports load/store that enables fault
+  /// suppression of memory operands when the source condition is false.
+  bool hasConditionalFaultingLoadStoreForType(Type *Ty) const;
+
   /// \return the target-provided register class ID for the provided type,
   /// accounting for type promotion and other type-legalization techniques that
   /// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
   virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
                                              const Function &Fn) const = 0;
   virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
+  virtual bool hasConditionalFaultingLoadStoreForType(Type *Ty) const = 0;
   virtual unsigned getRegisterClassForType(bool Vector,
                                            Type *Ty = nullptr) const = 0;
   virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getNumberOfRegisters(unsigned ClassID) const override {
     return Impl.getNumberOfRegisters(ClassID);
   }
+  bool hasConditionalFaultingLoadStoreForType(Type *Ty) const override {
+    return Impl.hasConditionalFaultingLoadStoreForType(Ty);
+  }
   unsigned getRegisterClassForType(bool Vector,
                                    Type *Ty = nullptr) const override {
     return Impl.getRegisterClassForType(Vector, Ty);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7828bdc1f1f43..a4aa836ed82d3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
   }
 
   unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
+  bool hasConditionalFaultingLoadStoreForType(Type *Ty) const { return false; }
 
   unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
     return Vector ? 1 : 0;
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 06f7ee2a589c8..2b0a45133bb0e 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3895,6 +3895,18 @@ class TargetLowering : public TargetLoweringBase {
                            const SDValue OldRHS, SDValue &Chain,
                            bool IsSignaling = false) const;
 
+  virtual SDValue visitMaskedLoadForCondFaulting(
+      SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+      SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+    llvm_unreachable("Not Implemented");
+  }
+
+  virtual SDValue visitMaskedStoreForCondFaulting(
+      SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+      SDValue Ptr, SDValue Val, SDValue Mask) const {
+    llvm_unreachable("Not Implemented");
+  }
+
   /// Returns a pair of (return value, chain).
   /// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
   std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7e721cbc87f3f..46936f266bf46 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -722,6 +722,11 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
   return TTIImpl->getNumberOfRegisters(ClassID);
 }
 
+bool TargetTransformInfo::hasConditionalFaultingLoadStoreForType(
+    Type *Ty) const {
+  return TTIImpl->hasConditionalFaultingLoadStoreForType(Ty);
+}
+
 unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
                                                       Type *Ty) const {
   return TTIImpl->getRegisterClassForType(Vector, Ty);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 296b06187ec0f..24dbe6efabbca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
       MachinePointerInfo(PtrOperand), MMOFlags,
       LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
+
+  const auto &TLI = DAG.getTargetLoweringInfo();
+  const auto &TTI =
+      TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
   SDValue StoreNode =
-      DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
-                         ISD::UNINDEXED, false /* Truncating */, IsCompressing);
+      (!IsCompressing && TTI.hasConditionalFaultingLoadStoreForType(
+                             I.getArgOperand(0)->getType()->getScalarType()))
+          ? TLI.visitMaskedStoreForCondFaulting(DAG, sdl, getMemoryRoot(), MMO,
+                                                Ptr, Src0, Mask)
+          : DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
+                               VT, MMO, ISD::UNINDEXED, false /* Truncating */,
+                               IsCompressing);
   DAG.setRoot(StoreNode);
   setValue(&I, StoreNode);
 }
@@ -4958,12 +4967,23 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
       MachinePointerInfo(PtrOperand), MMOFlags,
       LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
 
-  SDValue Load =
-      DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
-                        ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
+  const auto &TLI = DAG.getTargetLoweringInfo();
+  const auto &TTI =
+      TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+  // The Load/Res may point to different values.
+  SDValue Load;
+  SDValue Res;
+  if (!IsExpanding && TTI.hasConditionalFaultingLoadStoreForType(
+                          Src0Operand->getType()->getScalarType()))
+    Res = TLI.visitMaskedLoadForCondFaulting(DAG, sdl, InChain, MMO, Load, Ptr,
+                                             Src0, Mask);
+  else
+    Res = Load =
+        DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
+                          ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
   if (AddToChain)
     PendingLoads.push_back(Load.getValue(1));
-  setValue(&I, Load);
+  setValue(&I, Res);
 }
 
 void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f27c935812f51..2f89758c4783d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32308,6 +32308,54 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
   return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
 }
 
+static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
+                                      SDValue V) {
+  assert(V.getValueType() == MVT::i1 && "assume i1 value");
+  EVT Ty = MVT::i8;
+  SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
+  SDValue Zero = DAG.getConstant(0, DL, Ty);
+  SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
+  SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
+  return SDValue(CmpZero.getNode(), 1);
+}
+
+SDValue X86TargetLowering::visitMaskedLoadForCondFaulting(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+    SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+  // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+  // ->
+  // _, flags = SUB 0, mask
+  // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
+  // bit_cast_to_vector<res>
+  EVT VTy = PassThru.getValueType();
+  EVT Ty = VTy.getVectorElementType();
+  SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
+  SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+  SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+  SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+  SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+  SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
+  NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
+  return DAG.getBitcast(VTy, NewLoad);
+}
+
+SDValue X86TargetLowering::visitMaskedStoreForCondFaulting(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+    SDValue Ptr, SDValue Val, SDValue Mask) const {
+  // llvm.masked.store.*(Src0, Ptr, alignment, Mask)
+  // ->
+  // _, flags = SUB 0, mask
+  // chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
+  EVT Ty = Val.getValueType().getVectorElementType();
+  SDVTList Tys = DAG.getVTList(MVT::Other);
+  SDValue ScalarVal = DAG.getBitcast(Ty, Val);
+  SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+  SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+  SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+  SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
+  return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
+}
+
 /// Provide custom lowering hooks for some operations.
 SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -34024,6 +34072,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(STRICT_FP80_ADD)
   NODE_NAME_CASE(CCMP)
   NODE_NAME_CASE(CTEST)
+  NODE_NAME_CASE(CLOAD)
+  NODE_NAME_CASE(CSTORE)
   }
   return nullptr;
 #undef NODE_NAME_CASE
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 3c5c903bc0d98..05ef982ef2023 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -903,6 +903,10 @@ namespace llvm {
     // is needed so that this can be expanded with control flow.
     VASTART_SAVE_XMM_REGS,
 
+    // Conditional load/store instructions
+    CLOAD,
+    CSTORE,
+
     // WARNING: Do not add anything in the end unless you want the node to
     // have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
     // opcodes will be thought as target memory ops!
@@ -1556,6 +1560,18 @@ namespace llvm {
     bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
                                  unsigned OpNo) const override;
 
+    SDValue visitMaskedLoadForCondFaulting(SelectionDAG &DAG, const SDLoc &DL,
+                                           SDValue Chain,
+                                           MachineMemOperand *MMO,
+                                           SDValue &NewLoad, SDValue Ptr,
+                                           SDValue PassThru,
+                                           SDValue Mask) const override;
+    SDValue visitMaskedStoreForCondFaulting(SelectionDAG &DAG, const SDLoc &DL,
+                                            SDValue Chain,
+                                            MachineMemOperand *MMO, SDValue Ptr,
+                                            SDValue Val,
+                                            SDValue Mask) const override;
+
     /// Lower interleaved load(s) into target specific
     /// instructions/intrinsics.
     bool lowerInterleavedLoad(LoadInst *LI,
diff --git a/llvm/lib/Target/X86/X86InstrCMovSetCC.td b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
index e27aa4115990e..543057c58035a 100644
--- a/llvm/lib/Target/X86/X86InstrCMovSetCC.td
+++ b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
@@ -113,6 +113,35 @@ let Predicates = [HasCMOV, HasCF] in {
             (CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
   def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
             (CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;
+
+  def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+            (CFCMOV16rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+            (CFCMOV32rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+            (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+  // FIXME: Shouldn't patterns for 0 work for undef?
+  def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+            (CFCMOV16rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+            (CFCMOV32rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+            (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+  def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
+            (CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
+  def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
+            (CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
+  def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
+            (CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;
+
+  def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
+            (CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
+  def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
+            (CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
+  def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
+            (CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
 }
 
 // SetCC instructions.
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 162e322712a6d..972b56e0f0cfe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,6 +15,15 @@ def SDTX86FCmp    : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
 def SDTX86Ccmp    : SDTypeProfile<1, 5,
                                   [SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
 
+// res, chain = CLOAD inchain, ptr, passthru, cond, flags
+def SDTX86Cload    : SDTypeProfile<1, 4,
+                                  [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
+                                   SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
+// chain = CSTORE inchain, val, ptr, cond, flags
+def SDTX86Cstore    : SDTypeProfile<0, 4,
+                                  [SDTCisInt<0>, SDTCisPtrTy<1>,
+                                   SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
+
 def SDTX86Cmov    : SDTypeProfile<1, 4,
                                   [SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
                                    SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt      : SDNode<"X86ISD::BT",       SDTX86CmpTest>;
 def X86ccmp    : SDNode<"X86ISD::CCMP",     SDTX86Ccmp>;
 def X86ctest   : SDNode<"X86ISD::CTEST",    SDTX86Ccmp>;
 
+def X86cload    : SDNode<"X86ISD::CLOAD",   SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def X86cstore   : SDNode<"X86ISD::CSTORE",  SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
 def X86cmov    : SDNode<"X86ISD::CMOV",     SDTX86Cmov>;
 def X86brcond  : SDNode<"X86ISD::BRCOND",   SDTX86BrCond,
                         [SDNPHasChain]>;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index de0144331dba3..1100be925b127 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -176,6 +176,23 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
   return 8;
 }
 
+bool X86TTIImpl::hasConditionalFaultingLoadStoreForType(Type *Ty) const {
+  // Conditional faulting is supported by CFCMOV, which only accepts
+  // 16/32/64-bit operands.
+  // NOTE: Though VMOVSS/VMOVSD suppresses memory fault with zero mask, it has
+  // performance penalty.
+  if (!ST->hasCF() || !Ty || !Ty->isIntegerTy())
+    return false;
+  switch (cast<IntegerType>(Ty)->getBitWidth()) {
+  default:
+    return false;
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  }
+}
+
 TypeSize
 X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5891,14 +5908,21 @@ bool X86TTIImpl::canMacroFuseCmp() {
 }
 
 bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+  bool IsSingleElementVector =
+      isa<VectorType>(DataTy) &&
+      cast<FixedVectorType>(DataTy)->getNumElements() == 1;
+  Type *ScalarTy = DataTy->getScalarType();
+
+  if (ST->hasCF() && IsSingleElementVector &&
+      hasConditionalFaultingLoadStoreForType(ScalarTy))
+    return true;
+
   if (!ST->hasAVX())
     return false;
 
-  // The backend can't handle a single element vector.
-  if (isa<VectorType>(DataTy) &&
-      cast<FixedVectorType>(DataTy)->getNumElements() == 1)
+  // The backend can't handle a single element vector w/o CFCMOV.
+  if (IsSingleElementVector)
     return false;
-  Type *ScalarTy = DataTy->getScalarType();
 
   if (ScalarTy->isPointerTy())
     return true;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index e14dc9fc09051..701648c6a2b3a 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -132,6 +132,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// @{
 
   unsigned getNumberOfRegisters(unsigned ClassID) const;
+  bool hasConditionalFaultingLoadStoreForType(Type *Ty) const;
   TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
   unsigned getLoadStoreVecRegBitWidth(unsigned AS) const;
   unsigned getMaxInterleaveFactor(ElementCount VF);
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 4e2dc7f2b2f4e..86bae71768dcd 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -131,6 +131,12 @@ static cl::opt<bool> HoistCondStores(
     "simplifycfg-hoist-cond-stores", cl::Hidden, cl::init(true),
     cl::desc("Hoist conditional stores if an unconditional store precedes"));
 
+static cl::opt<bool> HoistLoadsStoresWithCondFaulting(
+    "simplifycfg-hoist-loads-stores-with-cond-faulting", cl::Hidden,
+    cl::init(true),
+    cl::desc("Hoist loads/stores if the target supports "
+             "conditional faulting"));
+
 static cl::opt<bool> MergeCondStores(
     "simplifycfg-merge-cond-stores", cl::Hidden, cl::init(true),
     cl::desc("Hoist conditional stores even if an unconditional store does not "
@@ -275,6 +281,7 @@ class SimplifyCFGOpt {
   bool hoistSuccIdenticalTerminatorToSwitchOrIf(
       Instruction *TI, Instruction *I1,
       SmallVectorImpl<Instruction *> &OtherSuccTIs);
+  bool hoistLoadStoreWithCondFaultingFromSuccessors(BasicBlock *BB);
   bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB);
   bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond,
                                   BasicBlock *TrueBB, BasicBlock *FalseBB,
@@ -2960,6 +2967,177 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
   return HaveRewritablePHIs;
 }
 
+/// Hoist load/store instructions from the conditional successor blocks up into
+/// the block.
+///
+/// We are looking for code like the following:
+/// \code
+///   BB:
+///     ...
+///     %cond = icmp ult %x, %y
+///     br i1 %cond, label %TrueBB, label %FalseBB
+///   FalseBB:
+///     store i32 1, ptr %q, align 4
+///     ...
+///   TrueBB:
+///     %0 = load i32, ptr %b, align 4
+///     store i32 %0, ptr %p, align 4
+///     ...
+/// \endcode
+//
+/// We are going to transform this into:
+///
+/// \code
+///   BB:
+///     ...
+///     %cond = icmp ult %x, %y
+///     %0 = cload i32, ptr %b, %cond
+///     cstore i32 %0, ptr %p, %cond
+///     cstore i32 1, p...
[truncated]

``````````
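
As a companion to the truncated SimplifyCFG portion above, a minimal sketch of the codegen side (the function name is hypothetical; the added llvm/test/CodeGen/X86/apx/cf.ll carries the real FileCheck lines). A 1 x i64 masked load like this should be built as the new X86ISD::CLOAD node and selected by the new CFCMOV patterns when compiled with `-mattr=+cf`:

```llvm
define i64 @cond_load(ptr %p, i1 %cond) {
  %mask = bitcast i1 %cond to <1 x i1>
  %v = call <1 x i64> @llvm.masked.load.v1i64.p0(ptr %p, i32 8, <1 x i1> %mask, <1 x i64> poison)
  %r = bitcast <1 x i64> %v to i64
  ret i64 %r
}

declare <1 x i64> @llvm.masked.load.v1i64.p0(ptr, i32 immarg, <1 x i1>, <1 x i64>)
```

Note how the lowering materializes EFLAGS: getFlagsOfCmpZeroFori1 emits `SUB 0, zext(mask)`, so ZF is set exactly when the mask is 0, and the CLOAD/CSTORE nodes consume X86::COND_NE, i.e. the memory access is performed only when the mask is 1.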

</details>


https://github.com/llvm/llvm-project/pull/95515

