[llvm] c507848 - [LoongArch] Optimize extractelement containing variable index for lasx (#151475)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 3 18:27:48 PDT 2025


Author: ZhaoQi
Date: 2025-09-04T09:27:44+08:00
New Revision: c5078484ff8cf35c369832d903d363c3019ef3e1

URL: https://github.com/llvm/llvm-project/commit/c5078484ff8cf35c369832d903d363c3019ef3e1
DIFF: https://github.com/llvm/llvm-project/commit/c5078484ff8cf35c369832d903d363c3019ef3e1.diff

LOG: [LoongArch] Optimize extractelement containing variable index for lasx (#151475)

Ideas suggested by: @heiher @tangaac

Added: 
    

Modified: 
    llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
    llvm/lib/Target/LoongArch/LoongArchISelLowering.h
    llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
    llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index c2398d8b58a03..1106be9146f88 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -423,6 +423,11 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::BITCAST);
   }
 
+  // Set DAG combine for 'LASX' feature.
+
+  if (Subtarget.hasExtLASX())
+    setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
+
   // Compute derived properties from the register classes.
   computeRegisterProperties(Subtarget.getRegisterInfo());
 
@@ -2778,14 +2783,65 @@ SDValue LoongArchTargetLowering::lowerCONCAT_VECTORS(SDValue Op,
 SDValue
 LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                  SelectionDAG &DAG) const {
-  EVT VecTy = Op->getOperand(0)->getValueType(0);
+  MVT EltVT = Op.getSimpleValueType();
+  SDValue Vec = Op->getOperand(0);
+  EVT VecTy = Vec->getValueType(0);
   SDValue Idx = Op->getOperand(1);
-  unsigned NumElts = VecTy.getVectorNumElements();
+  SDLoc DL(Op);
+  MVT GRLenVT = Subtarget.getGRLenVT();
+
+  assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type");
 
-  if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts)
+  if (isa<ConstantSDNode>(Idx))
     return Op;
 
-  return SDValue();
+  // Variable index: move the requested element into lane 0 with a
+  // per-element shuffle/permute, then extract lane 0 with a constant index
+  // (which the early return above leaves as-is).
+  switch (VecTy.getSimpleVT().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected type");
+  case MVT::v32i8:
+  case MVT::v16i16:
+  case MVT::v4i64:
+  case MVT::v4f64: {
+    // Extract the high half subvector and place it to the low half of a new
+    // vector. It doesn't matter what the high half of the new vector is.
+    EVT HalfTy = VecTy.getHalfNumVectorElementsVT(*DAG.getContext());
+    SDValue VecHi =
+        DAG.getExtractSubvector(DL, HalfTy, Vec, HalfTy.getVectorNumElements());
+    SDValue TmpVec =
+        DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecTy, DAG.getUNDEF(VecTy),
+                    VecHi, DAG.getConstant(0, DL, GRLenVT));
+
+    // Shuffle the origin Vec and the TmpVec using MaskVec, the lowest element
+    // of MaskVec is Idx, the rest do not matter. ResVec[0] will hold the
+    // desired element.
+    // MOVGR2FR_W_LA64 is an LA64-only node with an i64 operand. On LA32 the
+    // index is already an i32 value, so a plain bitcast to f32 suffices.
+    SDValue IdxCp =
+        Subtarget.is64Bit()
+            ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64, DL, MVT::f32, Idx)
+            : DAG.getBitcast(MVT::f32, Idx);
+    SDValue IdxVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f32, IdxCp);
+    SDValue MaskVec =
+        DAG.getBitcast((VecTy == MVT::v4f64) ? MVT::v4i64 : VecTy, IdxVec);
+    SDValue ResVec =
+        DAG.getNode(LoongArchISD::VSHUF, DL, VecTy, MaskVec, TmpVec, Vec);
+
+    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ResVec,
+                       DAG.getConstant(0, DL, GRLenVT));
+  }
+  case MVT::v8i32:
+  case MVT::v8f32: {
+    SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx);
+    SDValue SplatValue =
+        DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx);
+
+    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue,
+                       DAG.getConstant(0, DL, GRLenVT));
+  }
+  }
 }
 
 SDValue
@@ -6152,6 +6201,58 @@ performSPLIT_PAIR_F64Combine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+/// Before legalization, fold a variable extract_vector_elt index of the form
+/// {zero,sign,any}_extend (truncate X) back to X for 256-bit vectors, so the
+/// redundant truncate/extend pair is not materialized.
+static SDValue
+performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const LoongArchSubtarget &Subtarget) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDValue Vec = N->getOperand(0);
+  EVT VecTy = Vec.getValueType();
+  SDValue Idx = N->getOperand(1);
+
+  if (!VecTy.is256BitVector() || isa<ConstantSDNode>(Idx))
+    return SDValue();
+
+  // Combine:
+  //   t2 = truncate t1
+  //   t3 = {zero/sign/any}_extend t2
+  //   t4 = extract_vector_elt t0, t3
+  // to:
+  //   t4 = extract_vector_elt t0, t1
+  unsigned IdxOp = Idx.getOpcode();
+  if (IdxOp != ISD::ZERO_EXTEND && IdxOp != ISD::SIGN_EXTEND &&
+      IdxOp != ISD::ANY_EXTEND)
+    return SDValue();
+
+  SDValue Trunc = Idx.getOperand(0);
+  if (Trunc.getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  // Only undo a round trip, so the new index keeps the original index type.
+  SDValue IdxOrig = Trunc.getOperand(0);
+  if (IdxOrig.getValueType() != Idx.getValueType())
+    return SDValue();
+
+  // Dropping the truncate replaces every index bit at or above the truncated
+  // width with the corresponding bit of t1. For any_extend those bits were
+  // unspecified anyway. For zero/sign_extend it is only exact because
+  // lowerEXTRACT_VECTOR_ELT consumes just the low 32 bits of the index (an
+  // f32 move or a v8i32 splat), so the truncate must keep at least 32 bits.
+  // A narrower truncate (e.g. to i4) would otherwise select another element.
+  // NOTE(review): at the ISD level the new index may be out of range even
+  // when the old one was not; this relies on that lowering.
+  if (IdxOp != ISD::ANY_EXTEND && Trunc.getScalarValueSizeInBits() < 32)
+    return SDValue();
+
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), N->getValueType(0),
+                     Vec, IdxOrig);
+}
+
 SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
                                                    DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -6185,6 +6270,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
     return performVMSKLTZCombine(N, DAG, DCI, Subtarget);
   case LoongArchISD::SPLIT_PAIR_F64:
     return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget);
+  case ISD::EXTRACT_VECTOR_ELT:
+    return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget);
   }
   return SDValue();
 }
@@ -6967,6 +7054,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(VREPLVEI)
     NODE_NAME_CASE(VREPLGR2VR)
     NODE_NAME_CASE(XVPERMI)
+    NODE_NAME_CASE(XVPERM)
     NODE_NAME_CASE(VPICK_SEXT_ELT)
     NODE_NAME_CASE(VPICK_ZEXT_ELT)
     NODE_NAME_CASE(VREPLVE)

diff  --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 316fbe66efbdd..9d14934a9d363 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -145,6 +145,7 @@ enum NodeType : unsigned {
   VREPLVEI,
   VREPLGR2VR,
   XVPERMI,
+  XVPERM,
 
   // Extended vector element extraction
   VPICK_SEXT_ELT,

diff  --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index cf63750461edd..a79c01cbe577a 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -10,8 +10,15 @@
 //
 //===----------------------------------------------------------------------===//
 
+// XVPERM: operand 1 is the source vector (same type as the result); operand 2
+// is an integer index vector holding one index per result element.
+def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>,
+                                               SDTCisVec<2>, SDTCisInt<2>,
+                                               SDTCisSameNumEltsAs<0, 2>]>;
+
 // Target nodes.
 def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
+def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>;
 def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
 def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
 def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
@@ -1866,6 +1870,12 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8),
 def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8),
           (XVPERMI_D v4f64:$xj, immZExt8: $ui8)>;
 
+// XVPERM_W (the index vector operand is v8i32 for both v8i32 and v8f32)
+def : Pat<(loongarch_xvperm v8i32:$xj, v8i32:$xk),
+          (XVPERM_W v8i32:$xj, v8i32:$xk)>;
+def : Pat<(loongarch_xvperm v8f32:$xj, v8i32:$xk),
+          (XVPERM_W v8f32:$xj, v8i32:$xk)>;
+
 // XVREPLVE0_{W/D}
 def : Pat<(lasxsplatf32 FPR32:$fj),
           (XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>;

diff  --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
index 2e1618748688a..dddee35fb9e78 100644
--- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
+++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll
@@ -76,21 +76,11 @@ define void @extract_4xdouble(ptr %src, ptr %dst) nounwind {
 define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 ; CHECK-LABEL: extract_32xi8_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi.d $sp, $sp, -96
-; CHECK-NEXT:    st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT:    st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT:    addi.d $fp, $sp, 96
-; CHECK-NEXT:    bstrins.d $sp, $zero, 4, 0
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvst $xr0, $sp, 32
-; CHECK-NEXT:    addi.d $a0, $sp, 32
-; CHECK-NEXT:    bstrins.d $a0, $a2, 4, 0
-; CHECK-NEXT:    ld.b $a0, $a0, 0
-; CHECK-NEXT:    st.b $a0, $a1, 0
-; CHECK-NEXT:    addi.d $sp, $fp, -96
-; CHECK-NEXT:    ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT:    ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT:    addi.d $sp, $sp, 96
+; CHECK-NEXT:    xvpermi.q $xr1, $xr0, 1
+; CHECK-NEXT:    movgr2fr.w $fa2, $a2
+; CHECK-NEXT:    xvshuf.b $xr0, $xr1, $xr0, $xr2
+; CHECK-NEXT:    xvstelm.b $xr0, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <32 x i8>, ptr %src
   %e = extractelement <32 x i8> %v, i32 %idx
@@ -101,21 +91,11 @@ define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 ; CHECK-LABEL: extract_16xi16_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi.d $sp, $sp, -96
-; CHECK-NEXT:    st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT:    st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT:    addi.d $fp, $sp, 96
-; CHECK-NEXT:    bstrins.d $sp, $zero, 4, 0
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvst $xr0, $sp, 32
-; CHECK-NEXT:    addi.d $a0, $sp, 32
-; CHECK-NEXT:    bstrins.d $a0, $a2, 4, 1
-; CHECK-NEXT:    ld.h $a0, $a0, 0
-; CHECK-NEXT:    st.h $a0, $a1, 0
-; CHECK-NEXT:    addi.d $sp, $fp, -96
-; CHECK-NEXT:    ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT:    ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT:    addi.d $sp, $sp, 96
+; CHECK-NEXT:    xvpermi.q $xr1, $xr0, 1
+; CHECK-NEXT:    movgr2fr.w $fa2, $a2
+; CHECK-NEXT:    xvshuf.h $xr2, $xr1, $xr0
+; CHECK-NEXT:    xvstelm.h $xr2, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <16 x i16>, ptr %src
   %e = extractelement <16 x i16> %v, i32 %idx
@@ -126,21 +106,10 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 ; CHECK-LABEL: extract_8xi32_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi.d $sp, $sp, -96
-; CHECK-NEXT:    st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT:    st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT:    addi.d $fp, $sp, 96
-; CHECK-NEXT:    bstrins.d $sp, $zero, 4, 0
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvst $xr0, $sp, 32
-; CHECK-NEXT:    addi.d $a0, $sp, 32
-; CHECK-NEXT:    bstrins.d $a0, $a2, 4, 2
-; CHECK-NEXT:    ld.w $a0, $a0, 0
-; CHECK-NEXT:    st.w $a0, $a1, 0
-; CHECK-NEXT:    addi.d $sp, $fp, -96
-; CHECK-NEXT:    ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT:    ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT:    addi.d $sp, $sp, 96
+; CHECK-NEXT:    xvreplgr2vr.w $xr1, $a2
+; CHECK-NEXT:    xvperm.w $xr0, $xr0, $xr1
+; CHECK-NEXT:    xvstelm.w $xr0, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <8 x i32>, ptr %src
   %e = extractelement <8 x i32> %v, i32 %idx
@@ -151,21 +120,11 @@ define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 ; CHECK-LABEL: extract_4xi64_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi.d $sp, $sp, -96
-; CHECK-NEXT:    st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT:    st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT:    addi.d $fp, $sp, 96
-; CHECK-NEXT:    bstrins.d $sp, $zero, 4, 0
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvst $xr0, $sp, 32
-; CHECK-NEXT:    addi.d $a0, $sp, 32
-; CHECK-NEXT:    bstrins.d $a0, $a2, 4, 3
-; CHECK-NEXT:    ld.d $a0, $a0, 0
-; CHECK-NEXT:    st.d $a0, $a1, 0
-; CHECK-NEXT:    addi.d $sp, $fp, -96
-; CHECK-NEXT:    ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT:    ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT:    addi.d $sp, $sp, 96
+; CHECK-NEXT:    xvpermi.q $xr1, $xr0, 1
+; CHECK-NEXT:    movgr2fr.w $fa2, $a2
+; CHECK-NEXT:    xvshuf.d $xr2, $xr1, $xr0
+; CHECK-NEXT:    xvstelm.d $xr2, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <4 x i64>, ptr %src
   %e = extractelement <4 x i64> %v, i32 %idx
@@ -176,21 +135,10 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 ; CHECK-LABEL: extract_8xfloat_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi.d $sp, $sp, -96
-; CHECK-NEXT:    st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT:    st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT:    addi.d $fp, $sp, 96
-; CHECK-NEXT:    bstrins.d $sp, $zero, 4, 0
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvst $xr0, $sp, 32
-; CHECK-NEXT:    addi.d $a0, $sp, 32
-; CHECK-NEXT:    bstrins.d $a0, $a2, 4, 2
-; CHECK-NEXT:    fld.s $fa0, $a0, 0
-; CHECK-NEXT:    fst.s $fa0, $a1, 0
-; CHECK-NEXT:    addi.d $sp, $fp, -96
-; CHECK-NEXT:    ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT:    ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT:    addi.d $sp, $sp, 96
+; CHECK-NEXT:    xvreplgr2vr.w $xr1, $a2
+; CHECK-NEXT:    xvperm.w $xr0, $xr0, $xr1
+; CHECK-NEXT:    xvstelm.w $xr0, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <8 x float>, ptr %src
   %e = extractelement <8 x float> %v, i32 %idx
@@ -201,21 +149,11 @@ define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 define void @extract_4xdouble_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
 ; CHECK-LABEL: extract_4xdouble_idx:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi.d $sp, $sp, -96
-; CHECK-NEXT:    st.d $ra, $sp, 88 # 8-byte Folded Spill
-; CHECK-NEXT:    st.d $fp, $sp, 80 # 8-byte Folded Spill
-; CHECK-NEXT:    addi.d $fp, $sp, 96
-; CHECK-NEXT:    bstrins.d $sp, $zero, 4, 0
 ; CHECK-NEXT:    xvld $xr0, $a0, 0
-; CHECK-NEXT:    xvst $xr0, $sp, 32
-; CHECK-NEXT:    addi.d $a0, $sp, 32
-; CHECK-NEXT:    bstrins.d $a0, $a2, 4, 3
-; CHECK-NEXT:    fld.d $fa0, $a0, 0
-; CHECK-NEXT:    fst.d $fa0, $a1, 0
-; CHECK-NEXT:    addi.d $sp, $fp, -96
-; CHECK-NEXT:    ld.d $fp, $sp, 80 # 8-byte Folded Reload
-; CHECK-NEXT:    ld.d $ra, $sp, 88 # 8-byte Folded Reload
-; CHECK-NEXT:    addi.d $sp, $sp, 96
+; CHECK-NEXT:    xvpermi.q $xr1, $xr0, 1
+; CHECK-NEXT:    movgr2fr.w $fa2, $a2
+; CHECK-NEXT:    xvshuf.d $xr2, $xr1, $xr0
+; CHECK-NEXT:    xvstelm.d $xr2, $a1, 0, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <4 x double>, ptr %src
   %e = extractelement <4 x double> %v, i32 %idx


        


More information about the llvm-commits mailing list