[llvm] [X86][CodeGen] Support hoisting load/store with conditional faulting (PR #96720)
Shengchen Kan via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 26 00:56:38 PDT 2024
https://github.com/KanRobert updated https://github.com/llvm/llvm-project/pull/96720
>From 13a5fd552effb48315a835d43a4e4c9a3ea3b0d1 Mon Sep 17 00:00:00 2001
From: Shengchen Kan <shengchen.kan at intel.com>
Date: Wed, 5 Jun 2024 15:04:27 +0800
Subject: [PATCH 1/3] [X86][CodeGen] Support hoisting load/store with
conditional faulting
1. Add a TTI interface that reports whether the target supports conditional
(fault-suppressing) load/store.
2. Mark <1 x i16>/<1 x i32>/<1 x i64> masked loads/stores as legal so that
the scalarize-masked-mem-intrin pass does not expand them.
3. When visiting <1 x i16>/<1 x i32>/<1 x i64> masked loads/stores, build
target-specific CLOAD/CSTORE nodes to avoid an error in
`DAGTypeLegalizer::ScalarizeVectorResult`.
4. Add DAG combines that simplify the condition/flags operands of CLOAD/CSTORE.
5. Select CLOAD/CSTORE to CFCMOV instructions via TableGen patterns.
This is the CodeGen part of #95515.
---
.../llvm/Analysis/TargetTransformInfo.h | 8 ++
.../llvm/Analysis/TargetTransformInfoImpl.h | 1 +
llvm/include/llvm/CodeGen/TargetLowering.h | 14 +++
llvm/lib/Analysis/TargetTransformInfo.cpp | 4 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 31 +++++--
llvm/lib/Target/X86/X86ISelLowering.cpp | 83 ++++++++++++++++++
llvm/lib/Target/X86/X86ISelLowering.h | 12 +++
llvm/lib/Target/X86/X86InstrCMovSetCC.td | 29 +++++++
llvm/lib/Target/X86/X86InstrFragments.td | 12 +++
.../lib/Target/X86/X86TargetTransformInfo.cpp | 47 ++++++++--
llvm/lib/Target/X86/X86TargetTransformInfo.h | 1 +
llvm/test/CodeGen/X86/apx/cf.ll | 85 +++++++++++++++++++
12 files changed, 314 insertions(+), 13 deletions(-)
create mode 100644 llvm/test/CodeGen/X86/apx/cf.ll
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f55f21c94a85a..f5c0127e1d422 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
/// \return the number of registers in the target-provided register class.
unsigned getNumberOfRegisters(unsigned ClassID) const;
+ /// \return true if the target supports load/store that enables fault
+ /// suppression of memory operands when the source condition is false.
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
+
/// \return the target-provided register class ID for the provided type,
/// accounting for type promotion and other type-legalization techniques that
/// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
const Function &Fn) const = 0;
virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
+ virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
virtual unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const = 0;
virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
unsigned getNumberOfRegisters(unsigned ClassID) const override {
return Impl.getNumberOfRegisters(ClassID);
}
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
+ return Impl.hasConditionalLoadStoreForType(Ty);
+ }
unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const override {
return Impl.getRegisterClassForType(Vector, Ty);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7828bdc1f1f43..49b4bd00baed4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
}
unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
+ bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }
unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
return Vector ? 1 : 0;
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 06f7ee2a589c8..9a0df8b29d752 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3895,6 +3895,20 @@ class TargetLowering : public TargetLoweringBase {
const SDValue OldRHS, SDValue &Chain,
bool IsSignaling = false) const;
+ virtual SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, MachineMemOperand *MMO,
+ SDValue &NewLoad, SDValue Ptr,
+ SDValue PassThru, SDValue Mask) const {
+ llvm_unreachable("Not Implemented");
+ }
+
+ virtual SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, MachineMemOperand *MMO,
+ SDValue Ptr, SDValue Val,
+ SDValue Mask) const {
+ llvm_unreachable("Not Implemented");
+ }
+
/// Returns a pair of (return value, chain).
/// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7e721cbc87f3f..0db8a4201fead 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -722,6 +722,10 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
return TTIImpl->getNumberOfRegisters(ClassID);
}
+bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
+ return TTIImpl->hasConditionalLoadStoreForType(Ty);
+}
+
unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
Type *Ty) const {
return TTIImpl->getRegisterClassForType(Vector, Ty);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 296b06187ec0f..1f9e73ef949e8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
+
+ const auto &TLI = DAG.getTargetLoweringInfo();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
SDValue StoreNode =
- DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
- ISD::UNINDEXED, false /* Truncating */, IsCompressing);
+ (!IsCompressing && TTI.hasConditionalLoadStoreForType(
+ I.getArgOperand(0)->getType()->getScalarType()))
+ ? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
+ Mask)
+ : DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
+ VT, MMO, ISD::UNINDEXED, /*Truncating=*/false,
+ IsCompressing);
DAG.setRoot(StoreNode);
setValue(&I, StoreNode);
}
@@ -4958,12 +4967,22 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
- SDValue Load =
- DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
- ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
+ const auto &TLI = DAG.getTargetLoweringInfo();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+ // Load provides the chain and Res the result; they differ for CLOAD.
+ SDValue Load;
+ SDValue Res;
+ if (!IsExpanding && VT.getVectorElementCount().isScalar() &&
+ TTI.hasConditionalLoadStoreForType(Src0Operand->getType()->getScalarType()))
+ Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
+ else
+ Res = Load =
+ DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
+ ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
if (AddToChain)
PendingLoads.push_back(Load.getValue(1));
- setValue(&I, Load);
+ setValue(&I, Res);
}
void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f27c935812f51..a45e18ae67a91 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32308,6 +32308,55 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
}
+static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue V) {
+ assert(V.getValueType() == MVT::i1 && "assume i1 value");
+ EVT Ty = MVT::i8;
+ SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
+ SDValue Zero = DAG.getConstant(0, DL, Ty);
+ SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
+ SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
+ return SDValue(CmpZero.getNode(), 1);
+}
+
+SDValue X86TargetLowering::visitMaskedLoad(
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+ SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+ // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+ // ->
+ // _, flags = SUB 0, mask
+ // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
+ // bit_cast_to_vector<res>
+ EVT VTy = PassThru.getValueType();
+ EVT Ty = VTy.getVectorElementType();
+ SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
+ SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+ SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+ SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+ SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
+ NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
+ return DAG.getBitcast(VTy, NewLoad);
+}
+
+SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain,
+ MachineMemOperand *MMO, SDValue Ptr,
+ SDValue Val, SDValue Mask) const {
+ // llvm.masked.store.*(Src0, Ptr, alignment, Mask)
+ // ->
+ // _, flags = SUB 0, mask
+ // chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
+ EVT Ty = Val.getValueType().getVectorElementType();
+ SDVTList Tys = DAG.getVTList(MVT::Other);
+ SDValue ScalarVal = DAG.getBitcast(Ty, Val);
+ SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+ SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+ SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
+ return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
+}
+
/// Provide custom lowering hooks for some operations.
SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -34024,6 +34073,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(STRICT_FP80_ADD)
NODE_NAME_CASE(CCMP)
NODE_NAME_CASE(CTEST)
+ NODE_NAME_CASE(CLOAD)
+ NODE_NAME_CASE(CSTORE)
}
return nullptr;
#undef NODE_NAME_CASE
@@ -55633,6 +55684,36 @@ static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
+static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
+ // res, flags2 = sub 0, (setcc cc, flag)
+ // cload/cstore ..., cond_ne, flag2
+ // ->
+ // cload/cstore cc, flag
+ //
+ // if res has no users, where op is cload/cstore.
+ if (N->getConstantOperandVal(3) != X86::COND_NE)
+ return SDValue();
+
+ SDNode *Sub = N->getOperand(4).getNode();
+ if (Sub->getOpcode() != X86ISD::SUB)
+ return SDValue();
+
+ SDValue Op1 = Sub->getOperand(1);
+
+ if (Sub->hasAnyUseOfValue(0) || !X86::isZeroNode(Sub->getOperand(0)) ||
+ Op1.getOpcode() != X86ISD::SETCC)
+ return SDValue();
+
+
+ SmallVector<SDValue> Ops(N->op_values());
+ Ops[3] = Op1.getOperand(0);
+ Ops[4] = Op1.getOperand(1);
+
+ return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
+ cast<MemSDNode>(N)->getMemoryVT(),
+ cast<MemSDNode>(N)->getMemOperand());
+}
+
static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -57340,6 +57421,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SUB: return combineSub(N, DAG, DCI, Subtarget);
case X86ISD::ADD:
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI, Subtarget);
+ case X86ISD::CLOAD:
+ case X86ISD::CSTORE: return combineX86CloadCstore(N, DAG);
case X86ISD::SBB: return combineSBB(N, DAG);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 3c5c903bc0d98..362daa98e1f8e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -903,6 +903,10 @@ namespace llvm {
// is needed so that this can be expanded with control flow.
VASTART_SAVE_XMM_REGS,
+ // Conditional load/store instructions
+ CLOAD,
+ CSTORE,
+
// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
// opcodes will be thought as target memory ops!
@@ -1556,6 +1560,14 @@ namespace llvm {
bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
unsigned OpNo) const override;
+ SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+ MachineMemOperand *MMO, SDValue &NewLoad,
+ SDValue Ptr, SDValue PassThru,
+ SDValue Mask) const override;
+ SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+ MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
+ SDValue Mask) const override;
+
/// Lower interleaved load(s) into target specific
/// instructions/intrinsics.
bool lowerInterleavedLoad(LoadInst *LI,
diff --git a/llvm/lib/Target/X86/X86InstrCMovSetCC.td b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
index e27aa4115990e..543057c58035a 100644
--- a/llvm/lib/Target/X86/X86InstrCMovSetCC.td
+++ b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
@@ -113,6 +113,35 @@ let Predicates = [HasCMOV, HasCF] in {
(CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;
+
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV16rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV32rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+ // FIXME: Shouldn't patterns for 0 work for undef?
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV16rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV32rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+ def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
+ (CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
+ def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
+ (CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
+ def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
+ (CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;
+
+ def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
+ def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
+ def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
}
// SetCC instructions.
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 162e322712a6d..972b56e0f0cfe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,6 +15,15 @@ def SDTX86FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
def SDTX86Ccmp : SDTypeProfile<1, 5,
[SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
+// res, chain = CLOAD inchain, ptr, passthru, cond, flags
+def SDTX86Cload : SDTypeProfile<1, 4,
+ [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
+ SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
+// chain = CSTORE inchain, val, ptr, cond, flags
+def SDTX86Cstore : SDTypeProfile<0, 4,
+ [SDTCisInt<0>, SDTCisPtrTy<1>,
+ SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
+
def SDTX86Cmov : SDTypeProfile<1, 4,
[SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt : SDNode<"X86ISD::BT", SDTX86CmpTest>;
def X86ccmp : SDNode<"X86ISD::CCMP", SDTX86Ccmp>;
def X86ctest : SDNode<"X86ISD::CTEST", SDTX86Ccmp>;
+def X86cload : SDNode<"X86ISD::CLOAD", SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def X86cstore : SDNode<"X86ISD::CSTORE", SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
def X86cmov : SDNode<"X86ISD::CMOV", SDTX86Cmov>;
def X86brcond : SDNode<"X86ISD::BRCOND", SDTX86BrCond,
[SDNPHasChain]>;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index de0144331dba3..aad4b9039bbb1 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -176,6 +176,27 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
return 8;
}
+bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
+ if (!ST->hasCF())
+ return false;
+ if (!Ty)
+ return true;
+ // Conditional faulting is supported by CFCMOV, which only accepts
+ // 16/32/64-bit operands.
+ // TODO: Support f32/f64 with VMOVSS/VMOVSD with zero mask when it's
+ // profitable.
+ if (!Ty->isIntegerTy())
+ return false;
+ switch (cast<IntegerType>(Ty)->getBitWidth()) {
+ default:
+ return false;
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ }
+}
+
TypeSize
X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5062,7 +5083,12 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcVTy);
auto VT = TLI->getValueType(DL, SrcVTy);
InstructionCost Cost = 0;
- if (VT.isSimple() && LT.second != VT.getSimpleVT() &&
+ MVT Ty = LT.second;
+ if (Ty == MVT::i16 || Ty == MVT::i32 || Ty == MVT::i64)
+ // APX masked load/store for scalar is cheap.
+ return Cost + LT.first;
+
+ if (VT.isSimple() && Ty != VT.getSimpleVT() &&
LT.second.getVectorNumElements() == NumElem)
// Promotion requires extend/truncate for data and a shuffle for mask.
Cost += getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVTy, std::nullopt,
@@ -5070,9 +5096,9 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
getShuffleCost(TTI::SK_PermuteTwoSrc, MaskTy, std::nullopt,
CostKind, 0, nullptr);
- else if (LT.first * LT.second.getVectorNumElements() > NumElem) {
+ else if (LT.first * Ty.getVectorNumElements() > NumElem) {
auto *NewMaskTy = FixedVectorType::get(MaskTy->getElementType(),
- LT.second.getVectorNumElements());
+ Ty.getVectorNumElements());
// Expanding requires fill mask with zeroes
Cost += getShuffleCost(TTI::SK_InsertSubvector, NewMaskTy, std::nullopt,
CostKind, 0, MaskTy);
@@ -5891,14 +5917,21 @@ bool X86TTIImpl::canMacroFuseCmp() {
}
bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+ bool IsSingleElementVector =
+ isa<VectorType>(DataTy) &&
+ cast<FixedVectorType>(DataTy)->getNumElements() == 1;
+ Type *ScalarTy = DataTy->getScalarType();
+
+ if (ST->hasCF() && IsSingleElementVector &&
+ hasConditionalLoadStoreForType(ScalarTy))
+ return true;
+
if (!ST->hasAVX())
return false;
- // The backend can't handle a single element vector.
- if (isa<VectorType>(DataTy) &&
- cast<FixedVectorType>(DataTy)->getNumElements() == 1)
+ // The backend can't handle a single element vector w/o CFCMOV.
+ if (IsSingleElementVector)
return false;
- Type *ScalarTy = DataTy->getScalarType();
if (ScalarTy->isPointerTy())
return true;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index e14dc9fc09051..e6bb4720071d5 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -132,6 +132,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
/// @{
unsigned getNumberOfRegisters(unsigned ClassID) const;
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
unsigned getLoadStoreVecRegBitWidth(unsigned AS) const;
unsigned getMaxInterleaveFactor(ElementCount VF);
diff --git a/llvm/test/CodeGen/X86/apx/cf.ll b/llvm/test/CodeGen/X86/apx/cf.ll
new file mode 100644
index 0000000000000..1669c6c04c45a
--- /dev/null
+++ b/llvm/test/CodeGen/X86/apx/cf.ll
@@ -0,0 +1,85 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64 -mattr=+cf -verify-machineinstrs | FileCheck %s
+
+define void @basic(i32 %a, ptr %b, ptr %p, ptr %q) {
+; CHECK-LABEL: basic:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: testl %edi, %edi
+; CHECK-NEXT: cfcmovel (%rsi), %eax
+; CHECK-NEXT: cfcmovel %eax, (%rdx)
+; CHECK-NEXT: movl $1, %eax
+; CHECK-NEXT: cfcmovneq %rax, (%rdx)
+; CHECK-NEXT: movw $2, %ax
+; CHECK-NEXT: cfcmovnew %ax, (%rcx)
+; CHECK-NEXT: retq
+entry:
+ %cond = icmp eq i32 %a, 0
+ %0 = bitcast i1 %cond to <1 x i1>
+ %1 = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr %b, i32 4, <1 x i1> %0, <1 x i32> poison)
+ call void @llvm.masked.store.v1i32.p0(<1 x i32> %1, ptr %p, i32 4, <1 x i1> %0)
+ %2 = xor i1 %cond, true
+ %3 = bitcast i1 %2 to <1 x i1>
+ call void @llvm.masked.store.v1i64.p0(<1 x i64> <i64 1>, ptr %p, i32 8, <1 x i1> %3)
+ call void @llvm.masked.store.v1i16.p0(<1 x i16> <i16 2>, ptr %q, i32 8, <1 x i1> %3)
+ ret void
+}
+
+define i16 @cload_passthru_zero(i16 %a, ptr %b) {
+; CHECK-LABEL: cload_passthru_zero:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: testw %di, %di
+; CHECK-NEXT: cfcmovew (%rsi), %ax
+; CHECK-NEXT: retq
+entry:
+ %cond = icmp eq i16 %a, 0
+ %0 = bitcast i1 %cond to <1 x i1>
+ %1 = call <1 x i16> @llvm.masked.load.v1i16.p0(ptr %b, i32 4, <1 x i1> %0, <1 x i16> <i16 0>)
+ %2 = bitcast <1 x i16> %1 to i16
+ ret i16 %2
+}
+
+define i64 @cload_passthru_not_zero(i64 %a, ptr %b) {
+; CHECK-LABEL: cload_passthru_not_zero:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: testq %rdi, %rdi
+; CHECK-NEXT: cfcmoveq (%rsi), %rdi, %rax
+; CHECK-NEXT: retq
+entry:
+ %cond = icmp eq i64 %a, 0
+ %0 = bitcast i1 %cond to <1 x i1>
+ %va = bitcast i64 %a to <1 x i64>
+ %1 = call <1 x i64> @llvm.masked.load.v1i64.p0(ptr %b, i32 4, <1 x i1> %0, <1 x i64> %va)
+ %2 = bitcast <1 x i64> %1 to i64
+ ret i64 %2
+}
+
+;; The code generated for cond_false/cond_true need not be optimal, because
+;; the middle end should never emit a masked load with a constant mask. The
+;; IR is here only to check that feeding a constant mask to the backend works.
+define i16 @cond_false(ptr %b) {
+; CHECK-LABEL: cond_false:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xorl %eax, %eax
+; CHECK-NEXT: negb %al
+; CHECK-NEXT: cfcmovnew (%rdi), %ax
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast i1 false to <1 x i1>
+ %1 = call <1 x i16> @llvm.masked.load.v1i16.p0(ptr %b, i32 4, <1 x i1> %0, <1 x i16> <i16 0>)
+ %2 = bitcast <1 x i16> %1 to i16
+ ret i16 %2
+}
+
+define i64 @cond_true(ptr %b) {
+; CHECK-LABEL: cond_true:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: movb $1, %al
+; CHECK-NEXT: negb %al
+; CHECK-NEXT: cfcmovneq (%rdi), %rax
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast i1 true to <1 x i1>
+ %1 = call <1 x i64> @llvm.masked.load.v1i64.p0(ptr %b, i32 4, <1 x i1> %0, <1 x i64> <i64 0>)
+ %2 = bitcast <1 x i64> %1 to i64
+ ret i64 %2
+}
>From dba73bf8e83afe71313e755aadac1a6b51c6d6ac Mon Sep 17 00:00:00 2001
From: Shengchen Kan <shengchen.kan at intel.com>
Date: Wed, 26 Jun 2024 15:44:17 +0800
Subject: [PATCH 2/3] address review comments
---
.../SelectionDAG/SelectionDAGBuilder.cpp | 4 +--
llvm/lib/Target/X86/X86ISelLowering.cpp | 31 +++++++++----------
2 files changed, 16 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f9e73ef949e8..9f1a168c38241 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4788,8 +4788,8 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
const auto &TTI =
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
SDValue StoreNode =
- (!IsCompressing && TTI.hasConditionalLoadStoreForType(
- I.getArgOperand(0)->getType()->getScalarType()))
+ !IsCompressing && VT.getVectorElementCount().isScalar() &&
+ TTI.hasConditionalLoadStoreForType(I.getArgOperand(0)->getType()->getScalarType())
? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
Mask)
: DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a45e18ae67a91..ab29af91283ef 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32309,20 +32309,20 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
}
static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
- SDValue V) {
- assert(V.getValueType() == MVT::i1 && "assume i1 value");
+ SDValue Mask) {
EVT Ty = MVT::i8;
- SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
- SDValue Zero = DAG.getConstant(0, DL, Ty);
+ auto V = DAG.getBitcast(MVT::i1, Mask);
+ auto VE = DAG.getZExtOrTrunc(V, DL, Ty);
+ auto Zero = DAG.getConstant(0, DL, Ty);
SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
- SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
+ auto CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
return SDValue(CmpZero.getNode(), 1);
}
SDValue X86TargetLowering::visitMaskedLoad(
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
- // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+ // @llvm.masked.load.v1*(ptr, alignment, mask, passthru)
// ->
// _, flags = SUB 0, mask
// res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
@@ -32330,10 +32330,9 @@ SDValue X86TargetLowering::visitMaskedLoad(
EVT VTy = PassThru.getValueType();
EVT Ty = VTy.getVectorElementType();
SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
- SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
- SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
- SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
- SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ auto ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+ auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
+ auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
return DAG.getBitcast(VTy, NewLoad);
@@ -32343,16 +32342,15 @@ SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
SDValue Chain,
MachineMemOperand *MMO, SDValue Ptr,
SDValue Val, SDValue Mask) const {
- // llvm.masked.store.*(Src0, Ptr, alignment, Mask)
+ // llvm.masked.store.v1*(Src0, Ptr, alignment, Mask)
// ->
// _, flags = SUB 0, mask
// chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
EVT Ty = Val.getValueType().getVectorElementType();
SDVTList Tys = DAG.getVTList(MVT::Other);
- SDValue ScalarVal = DAG.getBitcast(Ty, Val);
- SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
- SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
- SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ auto ScalarVal = DAG.getBitcast(Ty, Val);
+ auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
+ auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
}
@@ -55704,8 +55702,7 @@ static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
Op1.getOpcode() != X86ISD::SETCC)
return SDValue();
-
- SmallVector<SDValue> Ops(N->op_values());
+ SmallVector<SDValue, 5> Ops(N->op_values());
Ops[3] = Op1.getOperand(0);
Ops[4] = Op1.getOperand(1);
>From 575e599420073d5ddf1d0fddd5ea28f1d815d8a3 Mon Sep 17 00:00:00 2001
From: Shengchen Kan <shengchen.kan at intel.com>
Date: Wed, 26 Jun 2024 15:56:02 +0800
Subject: [PATCH 3/3] address review comment
---
llvm/lib/Target/X86/X86InstrFragments.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 972b56e0f0cfe..038100b8264de 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,11 +15,11 @@ def SDTX86FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
def SDTX86Ccmp : SDTypeProfile<1, 5,
[SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
-// res, chain = CLOAD inchain, ptr, passthru, cond, flags
+// RES = op PTR, PASSTHRU, COND, EFLAGS
def SDTX86Cload : SDTypeProfile<1, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
-// chain = CSTORE inchain, val, ptr, cond, flags
+// op VAL, PTR, COND, EFLAGS
def SDTX86Cstore : SDTypeProfile<0, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>,
SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
More information about the llvm-commits
mailing list