[llvm] bba83e2 - [AArch64] LowerMUL - use SDValue directly instead of SDNode. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 22 02:45:22 PDT 2023


Author: Simon Pilgrim
Date: 2023-09-22T10:43:20+01:00
New Revision: bba83e20deec28572dbec4288cd0a135b4386c25

URL: https://github.com/llvm/llvm-project/commit/bba83e20deec28572dbec4288cd0a135b4386c25
DIFF: https://github.com/llvm/llvm-project/commit/bba83e20deec28572dbec4288cd0a135b4386c25.diff

LOG: [AArch64] LowerMUL - use SDValue directly instead of SDNode. NFC.

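As discussed on D159537, using the SDValue operands directly, instead of peeking inside at the underlying SDNode, avoids any issues where a non-zero result index has been used (the SDNode-based helpers implicitly assumed result 0, e.g. via getValueType(0) and SDValue(N, 0)).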

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4220c4ee106fe05..cd1f2401763589b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4399,11 +4399,11 @@ getConstantLaneNumOfExtractHalfOperand(SDValue &Op) {
   return C->getZExtValue();
 }
 
-static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG,
+static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
                                    bool isSigned) {
-  EVT VT = N->getValueType(0);
+  EVT VT = N.getValueType();
 
-  if (N->getOpcode() != ISD::BUILD_VECTOR)
+  if (N.getOpcode() != ISD::BUILD_VECTOR)
     return false;
 
   for (const SDValue &Elt : N->op_values()) {
@@ -4425,22 +4425,22 @@ static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG,
   return true;
 }
 
-static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) {
-  if (ISD::isExtOpcode(N->getOpcode()))
-    return addRequiredExtensionForVectorMULL(N->getOperand(0), DAG,
-                                             N->getOperand(0)->getValueType(0),
-                                             N->getValueType(0),
-                                             N->getOpcode());
+static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
+  if (ISD::isExtOpcode(N.getOpcode()))
+    return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
+                                             N.getOperand(0).getValueType(),
+                                             N.getValueType(),
+                                             N.getOpcode());
 
-  assert(N->getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
-  EVT VT = N->getValueType(0);
+  assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
+  EVT VT = N.getValueType();
   SDLoc dl(N);
   unsigned EltSize = VT.getScalarSizeInBits() / 2;
   unsigned NumElts = VT.getVectorNumElements();
   MVT TruncVT = MVT::getIntegerVT(EltSize);
   SmallVector<SDValue, 8> Ops;
   for (unsigned i = 0; i != NumElts; ++i) {
-    ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(i));
+    ConstantSDNode *C = cast<ConstantSDNode>(N.getOperand(i));
     const APInt &CInt = C->getAPIntValue();
     // Element types smaller than 32 bits are not legal, so use i32 elements.
     // The values are implicitly truncated so sext vs. zext doesn't matter.
@@ -4449,34 +4449,34 @@ static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) {
   return DAG.getBuildVector(MVT::getVectorVT(TruncVT, NumElts), dl, Ops);
 }
 
-static bool isSignExtended(SDNode *N, SelectionDAG &DAG) {
-  return N->getOpcode() == ISD::SIGN_EXTEND ||
-         N->getOpcode() == ISD::ANY_EXTEND ||
+static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
+  return N.getOpcode() == ISD::SIGN_EXTEND ||
+         N.getOpcode() == ISD::ANY_EXTEND ||
          isExtendedBUILD_VECTOR(N, DAG, true);
 }
 
-static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) {
-  return N->getOpcode() == ISD::ZERO_EXTEND ||
-         N->getOpcode() == ISD::ANY_EXTEND ||
+static bool isZeroExtended(SDValue N, SelectionDAG &DAG) {
+  return N.getOpcode() == ISD::ZERO_EXTEND ||
+         N.getOpcode() == ISD::ANY_EXTEND ||
          isExtendedBUILD_VECTOR(N, DAG, false);
 }
 
-static bool isAddSubSExt(SDNode *N, SelectionDAG &DAG) {
-  unsigned Opcode = N->getOpcode();
+static bool isAddSubSExt(SDValue N, SelectionDAG &DAG) {
+  unsigned Opcode = N.getOpcode();
   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
-    SDNode *N0 = N->getOperand(0).getNode();
-    SDNode *N1 = N->getOperand(1).getNode();
+    SDValue N0 = N.getOperand(0);
+    SDValue N1 = N.getOperand(1);
     return N0->hasOneUse() && N1->hasOneUse() &&
       isSignExtended(N0, DAG) && isSignExtended(N1, DAG);
   }
   return false;
 }
 
-static bool isAddSubZExt(SDNode *N, SelectionDAG &DAG) {
-  unsigned Opcode = N->getOpcode();
+static bool isAddSubZExt(SDValue N, SelectionDAG &DAG) {
+  unsigned Opcode = N.getOpcode();
   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
-    SDNode *N0 = N->getOperand(0).getNode();
-    SDNode *N1 = N->getOperand(1).getNode();
+    SDValue N0 = N.getOperand(0);
+    SDValue N1 = N.getOperand(1);
     return N0->hasOneUse() && N1->hasOneUse() &&
       isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG);
   }
@@ -4550,7 +4550,7 @@ SDValue AArch64TargetLowering::LowerSET_ROUNDING(SDValue Op,
   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
 }
 
-static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
+static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
                                  SDLoc DL, bool &IsMLA) {
   bool IsN0SExt = isSignExtended(N0, DAG);
   bool IsN1SExt = isSignExtended(N1, DAG);
@@ -4569,12 +4569,12 @@ static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
       !isExtendedBUILD_VECTOR(N1, DAG, false)) {
     SDValue ZextOperand;
     if (IsN0ZExt)
-      ZextOperand = N0->getOperand(0);
+      ZextOperand = N0.getOperand(0);
     else
-      ZextOperand = N1->getOperand(0);
+      ZextOperand = N1.getOperand(0);
     if (DAG.SignBitIsZero(ZextOperand)) {
-      SDNode *NewSext =
-          DAG.getSExtOrTrunc(ZextOperand, DL, N0->getValueType(0)).getNode();
+      SDValue NewSext =
+          DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
       if (IsN0ZExt)
         N0 = NewSext;
       else
@@ -4585,10 +4585,10 @@ static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
 
   // Select UMULL if we can replace the other operand with an extend.
   if (IsN0ZExt || IsN1ZExt) {
-    EVT VT = N0->getValueType(0);
+    EVT VT = N0.getValueType();
     APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
                                        VT.getScalarSizeInBits() / 2);
-    if (DAG.MaskedValueIsZero(SDValue(IsN0ZExt ? N1 : N0, 0), Mask)) {
+    if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask)) {
       EVT HalfVT;
       switch (VT.getSimpleVT().SimpleTy) {
       case MVT::v2i64:
@@ -4604,13 +4604,13 @@ static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
         return 0;
       }
       // Truncate and then extend the result.
-      SDValue NewExt = DAG.getNode(ISD::TRUNCATE, DL, HalfVT,
-                                   SDValue(IsN0ZExt ? N1 : N0, 0));
+      SDValue NewExt =
+          DAG.getNode(ISD::TRUNCATE, DL, HalfVT, IsN0ZExt ? N1 : N0);
       NewExt = DAG.getZExtOrTrunc(NewExt, DL, VT);
       if (IsN0ZExt)
-        N1 = NewExt.getNode();
+        N1 = NewExt;
       else
-        N0 = NewExt.getNode();
+        N0 = NewExt;
       return AArch64ISD::UMULL;
     }
   }
@@ -4647,18 +4647,18 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
   // that VMULL can be detected.  Otherwise v2i64 multiplications are not legal.
   assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() &&
          "unexpected type for custom-lowering ISD::MUL");
-  SDNode *N0 = Op.getOperand(0).getNode();
-  SDNode *N1 = Op.getOperand(1).getNode();
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
   bool isMLA = false;
   EVT OVT = VT;
   if (VT.is64BitVector()) {
-    if (N0->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        isNullConstant(N0->getOperand(1)) &&
-        N1->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        isNullConstant(N1->getOperand(1))) {
-      N0 = N0->getOperand(0).getNode();
-      N1 = N1->getOperand(0).getNode();
-      VT = N0->getValueType(0);
+    if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        isNullConstant(N0.getOperand(1)) &&
+        N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        isNullConstant(N1.getOperand(1))) {
+      N0 = N0.getOperand(0);
+      N1 = N1.getOperand(0);
+      VT = N0.getValueType();
     } else {
       if (VT == MVT::v1i64) {
         if (Subtarget->hasSVE())
@@ -4702,12 +4702,12 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
   // Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during
   // isel lowering to take advantage of no-stall back to back s/umul + s/umla.
   // This is true for CPUs with accumulate forwarding such as Cortex-A53/A57
-  SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG);
-  SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG);
+  SDValue N00 = skipExtensionForVectorMULL(N0.getOperand(0), DAG);
+  SDValue N01 = skipExtensionForVectorMULL(N0.getOperand(1), DAG);
   EVT Op1VT = Op1.getValueType();
   return DAG.getNode(
       ISD::EXTRACT_SUBVECTOR, DL, OVT,
-      DAG.getNode(N0->getOpcode(), DL, VT,
+      DAG.getNode(N0.getOpcode(), DL, VT,
                   DAG.getNode(NewOpc, DL, VT,
                               DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
                   DAG.getNode(NewOpc, DL, VT,
@@ -16476,8 +16476,8 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   if (TrailingZeroes) {
     // Conservatively do not lower to shift+add+shift if the mul might be
     // folded into smul or umul.
-    if (N0->hasOneUse() && (isSignExtended(N0.getNode(), DAG) ||
-                            isZeroExtended(N0.getNode(), DAG)))
+    if (N0->hasOneUse() && (isSignExtended(N0, DAG) ||
+                            isZeroExtended(N0, DAG)))
       return SDValue();
     // Conservatively do not lower to shift+add+shift if the mul might be
     // folded into madd or msub.


        


More information about the llvm-commits mailing list