[llvm] [AArch64] Avoid GPR trip when moving truncated i32 vector elements (PR #114541)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 18 08:40:22 PST 2024


================
@@ -20735,17 +20735,44 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performTruncateCombine(SDNode *N,
-                                      SelectionDAG &DAG) {
+static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  SDLoc DL(N);
   EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   if (VT.isFixedLengthVector() && VT.is64BitVector() && N0.hasOneUse() &&
       N0.getOpcode() == AArch64ISD::DUP) {
     SDValue Op = N0.getOperand(0);
     if (VT.getScalarType() == MVT::i32 &&
         N0.getOperand(0).getValueType().getScalarType() == MVT::i64)
-      Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i32, Op);
-    return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Op);
+      Op = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Op);
+    return DAG.getNode(N0.getOpcode(), DL, VT, Op);
+  }
+
+  // Performing the following combine produces a preferable form for ISEL.
+  // i32 (trunc (extract Vi64, idx)) -> i32 (extract (nvcast Vi32), idx*2))
+  if (DCI.isAfterLegalizeDAG() && N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+    SDValue Op = N0.getOperand(0);
+    SDValue ExtractIndexNode = N0.getOperand(1);
+    if (!isa<ConstantSDNode>(ExtractIndexNode))
+      return SDValue();
+
+    // For a legal DAG, EXTRACT_VECTOR_ELT can only have produced an i32 or i64.
+    // So we can only expect: i32 (trunc (i64 (extract Vi64, idx))).
+    assert((VT == MVT::i32 && N0.getValueType() == MVT::i64) &&
+           "Unexpected legalisation result!");
+
+    MVT CastVT;
+    EVT SrcVectorType = Op.getValueType();
+    assert(SrcVectorType.getScalarType() == MVT::i64);
+    unsigned ExtractIndex =
+        cast<ConstantSDNode>(ExtractIndexNode)->getZExtValue();
+
+    CastVT = SrcVectorType.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
+
+    Op = DAG.getNode(AArch64ISD::NVCAST, DL, CastVT, Op);
----------------
SpencerAbson wrote:

Ah, there's another assumption I've made here: the vector we are extracting from will always be 128 bits wide, because the DAG has already been legalised (see https://github.com/llvm/llvm-project/blob/0d76f393eb1ad65df2f602d97fba8f6ec2ab7aa8/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp#L15098).

Perhaps this is worth an assertion too?

https://github.com/llvm/llvm-project/pull/114541


More information about the llvm-commits mailing list