[llvm] [RISCV] Fold extract_vector_elt of a load into the scalar load (PR #76151)
Wang Pengcheng via llvm-commits
llvm-commits at lists.llvm.org
Sun Dec 24 22:53:31 PST 2023
wangpc-pp wrote:
I mean, if this PR is a replacement for `DAGCombiner::scalarizeExtractedVectorLoad`, then we should refer to it and align this implementation with it.
As for unaligned accesses, `DAGCombiner::scalarizeExtractedVectorLoad` stops the combine when the target doesn't support them:
```cpp
unsigned IsFast = 0;
if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
                            OriginalLoad->getAddressSpace(), Alignment,
                            OriginalLoad->getMemOperand()->getFlags(),
                            &IsFast) ||
    !IsFast)
  return SDValue();
```
For a non-constant element index, we can still do the combine by changing its `MachinePointerInfo` (keeping only the address space).
Here is my local git diff (but please double-check it); a small IR example illustrating the alignment check follows it:
```diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5c3f43a620f5..dd8ef309574c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14483,15 +14483,34 @@ performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
   auto *LoadVec = dyn_cast<LoadSDNode>(InputVec);
   EVT VecEltVT = InVecVT.getVectorElementType();
-  auto *CIdx = dyn_cast<ConstantSDNode>(EltIdx);
   // extract_vec_elt (load X), C --> scalar load (X+C)
-  if (LoadVec && CIdx && ISD::isNormalLoad(LoadVec) && LoadVec->isSimple()) {
+  if (LoadVec && ISD::isNormalLoad(LoadVec) && LoadVec->isSimple()) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     SDValue NewPtr = TLI.getVectorElementPointer(DAG, LoadVec->getBasePtr(),
                                                  InVecVT, EltIdx);
-    unsigned PtrOff = VecEltVT.getSizeInBits() * CIdx->getZExtValue() / 8;
-    MachinePointerInfo MPI = LoadVec->getPointerInfo().getWithOffset(PtrOff);
-    Align Alignment = commonAlignment(LoadVec->getAlign(), PtrOff);
+    Align Alignment = LoadVec->getAlign();
+    MachinePointerInfo MPI;
+    if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltIdx)) {
+      int Elt = ConstEltNo->getZExtValue();
+      unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
+      MPI = LoadVec->getPointerInfo().getWithOffset(PtrOff);
+      Alignment = commonAlignment(Alignment, PtrOff);
+    } else {
+      // Discard the pointer info except the address space because the memory
+      // operand can't represent this new access since the offset is variable.
+      MPI = MachinePointerInfo(LoadVec->getPointerInfo().getAddrSpace());
+      Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
+    }
+
+    // Don't perform the combination if unaligned access is not allowed.
+    unsigned IsFast = 0;
+    if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+                                VecEltVT, LoadVec->getAddressSpace(), Alignment,
+                                LoadVec->getMemOperand()->getFlags(),
+                                &IsFast) ||
+        !IsFast)
+      return SDValue();
+
     SDValue Load =
         DAG.getLoad(VecEltVT, DL, LoadVec->getChain(), NewPtr, MPI, Alignment,
                     LoadVec->getMemOperand()->getFlags(), LoadVec->getAAInfo());
```
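To illustrate why the `allowsMemoryAccess` check matters, here is a small IR case I wrote for this comment (not one of the in-tree tests): with an underaligned vector load, the scalarized element load would itself be unaligned, so the combine should bail out unless the target reports such an access as allowed and fast.
```llvm
define i16 @extractelt_v8i16_align1(ptr %x, i32 zeroext %idx) nounwind {
  ; The vector load is only 1-byte aligned, so the scalarized i16 load is
  ; potentially unaligned; the combine should be skipped on targets without
  ; fast unaligned scalar accesses.
  %a = load <8 x i16>, ptr %x, align 1
  %b = extractelement <8 x i16> %a, i32 %idx
  ret i16 %b
}
```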
For the non-constant index case, the output looks like:
```diff
 define i16 @extractelt_v8i16_idx(ptr %x, i32 zeroext %idx) nounwind {
 ; CHECK-LABEL: extractelt_v8i16_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    vslidedown.vx v8, v8, a1
-; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    andi a1, a1, 7
+; CHECK-NEXT:    slli a1, a1, 1
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    lh a0, 0(a0)
 ; CHECK-NEXT:    ret
   %a = load <8 x i16>, ptr %x
   %b = extractelement <8 x i16> %a, i32 %idx
   ret i16 %b
 }
```
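For comparison, a constant-index extract keeps going through the `getWithOffset` path of the diff above; a minimal example (again illustrative, the function name is not taken from the in-tree test):
```llvm
define i16 @extractelt_v8i16_const_idx(ptr %x) nounwind {
  ; Constant index 3: PtrOff = 16 * 3 / 8 = 6 bytes, so the pointer info and
  ; alignment are derived from the original vector load plus that offset.
  %a = load <8 x i16>, ptr %x
  %b = extractelement <8 x i16> %a, i32 3
  ret i16 %b
}
```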
https://github.com/llvm/llvm-project/pull/76151