[llvm] [X86][CodeGen] Support hoisting load/store with conditional faulting (PR #96720)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 26 00:19:10 PDT 2024


================
@@ -32308,6 +32308,55 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
   return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
 }
 
+static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
+                                      SDValue V) {
+  assert(V.getValueType() == MVT::i1 && "assume i1 value");
+  EVT Ty = MVT::i8; // Type the i1 input is converted to for the SUB below.
+  SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
+  SDValue Zero = DAG.getConstant(0, DL, Ty);
+  SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32); // Two results: i8 and i32.
+  SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE); // 0 - VE
+  return SDValue(CmpZero.getNode(), 1); // Return result #1 (the i32 one), not the i8 difference.
+}
+
+SDValue X86TargetLowering::visitMaskedLoad(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+    SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+  // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+  // ->
+  // _, flags = SUB 0, mask
+  // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
+  // bit_cast_to_vector<res>
+  EVT VTy = PassThru.getValueType();
+  EVT Ty = VTy.getVectorElementType();
+  SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
+  SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+  SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
----------------
phoebewang wrote:

You can move the `ScalarMask` computation (the bitcast of `Mask` to `MVT::i1`) into `getFlagsOfCmpZeroFori1` and remove the assert.

https://github.com/llvm/llvm-project/pull/96720


More information about the llvm-commits mailing list