[llvm] [RISCV] Fold extract_vector_elt of a load into the scalar load (PR #76151)

Wang Pengcheng via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 24 22:53:31 PST 2023


wangpc-pp wrote:

I mean: if this PR is meant as a replacement for `DAGCombiner::scalarizeExtractedVectorLoad`, then we should use that function as a reference and align our implementation with it.

For unaligned accesses, `DAGCombiner::scalarizeExtractedVectorLoad` bails out of the combine if the target does not allow the scalar access, or allows it but reports that it would not be fast:
```cpp
  unsigned IsFast = 0;
  if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
                              OriginalLoad->getAddressSpace(), Alignment,
                              OriginalLoad->getMemOperand()->getFlags(),
                              &IsFast) ||
      !IsFast)
    return SDValue();
```

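For a non-constant element index, we can still perform the fold; we just need to change its `MachinePointerInfo`. Because the offset is variable and cannot be represented in the memory operand, we keep only the address space.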

Here is my local git diff (please double-check it):
```diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5c3f43a620f5..dd8ef309574c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14483,15 +14483,34 @@ performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
 
   auto *LoadVec = dyn_cast<LoadSDNode>(InputVec);
   EVT VecEltVT = InVecVT.getVectorElementType();
-  auto *CIdx = dyn_cast<ConstantSDNode>(EltIdx);
   // extract_vec_elt (load X), C --> scalar load (X+C)
-  if (LoadVec && CIdx && ISD::isNormalLoad(LoadVec) && LoadVec->isSimple()) {
+  if (LoadVec && ISD::isNormalLoad(LoadVec) && LoadVec->isSimple()) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     SDValue NewPtr = TLI.getVectorElementPointer(DAG, LoadVec->getBasePtr(),
                                                  InVecVT, EltIdx);
-    unsigned PtrOff = VecEltVT.getSizeInBits() * CIdx->getZExtValue() / 8;
-    MachinePointerInfo MPI = LoadVec->getPointerInfo().getWithOffset(PtrOff);
-    Align Alignment = commonAlignment(LoadVec->getAlign(), PtrOff);
+    Align Alignment = LoadVec->getAlign();
+    MachinePointerInfo MPI;
+    if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltIdx)) {
+      int Elt = ConstEltNo->getZExtValue();
+      unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
+      MPI = LoadVec->getPointerInfo().getWithOffset(PtrOff);
+      Alignment = commonAlignment(Alignment, PtrOff);
+    } else {
+      // Discard the pointer info except the address space because the memory
+      // operand can't represent this new access since the offset is variable.
+      MPI = MachinePointerInfo(LoadVec->getPointerInfo().getAddrSpace());
+      Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
+    }
+
+    // Don't perform the combination if unaligned access is not allowed.
+    unsigned IsFast = 0;
+    if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+                                VecEltVT, LoadVec->getAddressSpace(), Alignment,
+                                LoadVec->getMemOperand()->getFlags(),
+                                &IsFast) ||
+        !IsFast)
+      return SDValue();
+
     SDValue Load =
         DAG.getLoad(VecEltVT, DL, LoadVec->getChain(), NewPtr, MPI, Alignment,
                     LoadVec->getMemOperand()->getFlags(), LoadVec->getAAInfo());
```

For the non-constant case, the generated code changes as follows:
```diff
define i16 @extractelt_v8i16_idx(ptr %x, i32 zeroext %idx) nounwind {
; CHECK-LABEL: extractelt_v8i16_idx:
; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    vslidedown.vx v8, v8, a1
-; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    andi a1, a1, 7
+; CHECK-NEXT:    slli a1, a1, 1
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    lh a0, 0(a0)
; CHECK-NEXT:    ret
  %a = load <8 x i16>, ptr %x
  %b = extractelement <8 x i16> %a, i32 %idx
  ret i16 %b
}
```

https://github.com/llvm/llvm-project/pull/76151


More information about the llvm-commits mailing list