[llvm] [LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout. (PR #104820)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 09:56:57 PDT 2024


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/104820

From 8ed29b7e193ff9bcf6b531abdef8827164b18a9d Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 16 Aug 2024 17:04:43 +0100
Subject: [PATCH 1/2] [LLVM][AArch64] Enable verifyTargetSDNode for scalable
 vectors and fix the fallout.

Fix incorrect use of AArch64ISD::UZP1/UUNPK{HI,LO} in:
  AArch64TargetLowering::LowerDIV
  AArch64TargetLowering::LowerINSERT_SUBVECTOR

The latter highlighted DAG combines that relied on broken behaviour,
which this patch also fixes.
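
As a rough illustrative sketch (not part of the patch itself, reusing the
names already present in AArch64ISelLowering.cpp): UUNPKHI produces a value
of the widened type, so under the now-enabled verifier it must be
reinterpreted back to the narrow type before it can feed UZP1, whose result
and operand types are required to match.

  // WideVT has half the lane count of NarrowVT but the same overall size.
  SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
  // Reinterpret the wide result as NarrowVT (same size, twice the lanes).
  HiVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, HiVec0);
  // All three types now match, satisfying verifyTargetSDNode for UZP1.
  SDValue Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);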
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 85 ++++++++++++++-----
 1 file changed, 62 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 28ad0abf25703b..7a677b9e9c97a2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14897,10 +14897,11 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
     // NOP cast operands to the largest legal vector of the same element count.
     if (VT.isFloatingPoint()) {
       Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
-      Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
+      Vec1 = getSVESafeBitCast(NarrowVT, Vec1, DAG);
     } else {
       // Legal integer vectors are already their largest so Vec0 is fine as is.
       Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
+      Vec1 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, Vec1);
     }
 
     // To replace the top/bottom half of vector V with vector SubV we widen the
@@ -14909,11 +14910,13 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
     SDValue Narrow;
     if (Idx == 0) {
       SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
+      HiVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, HiVec0);
       Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
     } else {
       assert(Idx == InVT.getVectorMinNumElements() &&
              "Invalid subvector index!");
       SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
+      LoVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, LoVec0);
       Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
     }
 
@@ -15013,7 +15016,9 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
   SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
   SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
   SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
-  return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
+  SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultLo);
+  SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultHi);
+  return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLoCast, ResultHiCast);
 }
 
 bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
@@ -22667,7 +22672,19 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   SDValue Rshrnb = DAG.getNode(
       AArch64ISD::RSHRNB_I, DL, ResVT,
       {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
-  return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
+  return DAG.getNode(AArch64ISD::NVCAST, DL, VT, Rshrnb);
+}
+
+static SDValue isNVCastToHalfWidthElements(SDValue V) {
+  if (V.getOpcode() != AArch64ISD::NVCAST)
+    return SDValue();
+
+  SDValue Op = V.getOperand(0);
+  if (V.getValueType().getVectorElementCount() !=
+      Op.getValueType().getVectorElementCount() * 2)
+    return SDValue();
+
+  return Op;
 }
 
 static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
@@ -22730,25 +22747,37 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
     return Urshr;
 
-  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
-    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+  if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
+      Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+    }
+  }
 
-  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
-    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+  if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
+      Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+    }
+  }
 
-  // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
-  if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
-    if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
-      SDValue X = Op0.getOperand(0).getOperand(0);
-      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+  // uzp1<ty>(nvcast(unpklo(uzp1<ty>(x, y))), z) => uzp1<ty>(x, z)
+  if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
+    if (PreCast.getOpcode() == AArch64ISD::UUNPKLO) {
+      if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+        SDValue X = PreCast.getOperand(0).getOperand(0);
+        return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+      }
     }
   }
 
-  // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
-  if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
-    if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
-      SDValue Z = Op1.getOperand(0).getOperand(1);
-      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+  // uzp1<ty>(x, nvcast(unpkhi(uzp1<ty>(y, z)))) => uzp1<ty>(x, z)
+  if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
+    if (PreCast.getOpcode() == AArch64ISD::UUNPKHI) {
+      if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+        SDValue Z = PreCast.getOperand(0).getOperand(1);
+        return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+      }
     }
   }
 
@@ -29343,9 +29372,6 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
            VT.isInteger() && "Expected integer vectors!");
     assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
            "Expected vectors of equal size!");
-    // TODO: Enable assert once bogus creations have been fixed.
-    if (VT.isScalableVector())
-      break;
     assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
            "Expected result vector with half the lanes of its input!");
     break;
@@ -29363,12 +29389,25 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
     EVT Op1VT = N->getOperand(1).getValueType();
     assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
            "Expected vectors!");
-    // TODO: Enable assert once bogus creations have been fixed.
-    if (VT.isScalableVector())
-      break;
     assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
     break;
   }
+  case AArch64ISD::RSHRNB_I: {
+    assert(N->getNumValues() == 1 && "Expected one result!");
+    assert(N->getNumOperands() == 2 && "Expected two operand!");
+    EVT VT = N->getValueType(0);
+    EVT Op0VT = N->getOperand(0).getValueType();
+    EVT Op1VT = N->getOperand(1).getValueType();
+    assert(VT.isVector() && Op0VT.isVector() && VT.isInteger() &&
+           Op0VT.isInteger() && "Expected integer vectors!");
+    assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
+           "Expected vectors of equal size!");
+    assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
+           "Expected input vector with half the lanes of its result!");
+    assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
+           "Expected second operand to be a constant i32!");
+    break;
+  }
   }
 }
 #endif

From 1eb24d2be1549c6352c4ed6d5304228d27bfbcb9 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 30 Aug 2024 16:56:20 +0000
Subject: [PATCH 2/2] Improve assert error messages.

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7a677b9e9c97a2..22c8993ba22ab5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29394,12 +29394,14 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
   }
   case AArch64ISD::RSHRNB_I: {
     assert(N->getNumValues() == 1 && "Expected one result!");
-    assert(N->getNumOperands() == 2 && "Expected two operand!");
+    assert(N->getNumOperands() == 2 && "Expected two operands!");
     EVT VT = N->getValueType(0);
     EVT Op0VT = N->getOperand(0).getValueType();
     EVT Op1VT = N->getOperand(1).getValueType();
-    assert(VT.isVector() && Op0VT.isVector() && VT.isInteger() &&
-           Op0VT.isInteger() && "Expected integer vectors!");
+    assert(VT.isVector() && VT.isInteger() &&
+           "Expected integer vector result type!");
+    assert(Op0VT.isVector() && Op0VT.isInteger() &&
+           "Expected first operand to be an integer vector!");
     assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
            "Expected vectors of equal size!");
     assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
