[llvm] 702c4ad - [ISD::IndexType] Helper functions for common queries.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Sat May 7 04:02:16 PDT 2022


Author: Paul Walker
Date: 2022-05-07T11:23:42+01:00
New Revision: 702c4ade225be58118d34760d9f67e02a826e426

URL: https://github.com/llvm/llvm-project/commit/702c4ade225be58118d34760d9f67e02a826e426
DIFF: https://github.com/llvm/llvm-project/commit/702c4ade225be58118d34760d9f67e02a826e426.diff

LOG: [ISD::IndexType] Helper functions for common queries.

Add helper functions to query the signed and scaled properties
of ISD::MemIndexType along with functions to change them.

Remove setIndexType from MaskedGatherScatterSDNode because it only
has one usage and the index type typically should only be changed
alongside its index operand.

Minimise the direct use of the enum values to lay the groundwork
for more refactoring.

Differential Revision: https://reviews.llvm.org/D123347

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ea1d3170acba2..d57756b30291a 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1358,6 +1358,26 @@ enum MemIndexType {
 
 static const int LAST_MEM_INDEX_TYPE = UNSIGNED_UNSCALED + 1;
 
+inline bool isIndexTypeScaled(MemIndexType IndexType) {
+  return IndexType == SIGNED_SCALED || IndexType == UNSIGNED_SCALED;
+}
+
+inline bool isIndexTypeSigned(MemIndexType IndexType) {
+  return IndexType == SIGNED_SCALED || IndexType == SIGNED_UNSCALED;
+}
+
+inline MemIndexType getSignedIndexType(MemIndexType IndexType) {
+  return isIndexTypeScaled(IndexType) ? SIGNED_SCALED : SIGNED_UNSCALED;
+}
+
+inline MemIndexType getUnsignedIndexType(MemIndexType IndexType) {
+  return isIndexTypeScaled(IndexType) ? UNSIGNED_SCALED : UNSIGNED_UNSCALED;
+}
+
+inline MemIndexType getUnscaledIndexType(MemIndexType IndexType) {
+  return isIndexTypeSigned(IndexType) ? SIGNED_UNSCALED : UNSIGNED_UNSCALED;
+}
+
 //===--------------------------------------------------------------------===//
 /// LoadExtType enum - This enum defines the three variants of LOADEXT
 /// (load with extension).

diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 51e80ebbcf0ee..659cfa268224a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -2702,14 +2702,8 @@ class VPGatherScatterSDNode : public MemSDNode {
   ISD::MemIndexType getIndexType() const {
     return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
   }
-  bool isIndexScaled() const {
-    return (getIndexType() == ISD::SIGNED_SCALED) ||
-           (getIndexType() == ISD::UNSIGNED_SCALED);
-  }
-  bool isIndexSigned() const {
-    return (getIndexType() == ISD::SIGNED_SCALED) ||
-           (getIndexType() == ISD::SIGNED_UNSCALED);
-  }
+  bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); }
+  bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); }
 
   // In the both nodes address is Op1, mask is Op2:
   // VPGatherSDNode  (Chain, base, index, scale, mask, vlen)
@@ -2790,17 +2784,8 @@ class MaskedGatherScatterSDNode : public MemSDNode {
   ISD::MemIndexType getIndexType() const {
     return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
   }
-  void setIndexType(ISD::MemIndexType IndexType) {
-    LSBaseSDNodeBits.AddressingMode = IndexType;
-  }
-  bool isIndexScaled() const {
-    return (getIndexType() == ISD::SIGNED_SCALED) ||
-           (getIndexType() == ISD::UNSIGNED_SCALED);
-  }
-  bool isIndexSigned() const {
-    return (getIndexType() == ISD::SIGNED_SCALED) ||
-           (getIndexType() == ISD::SIGNED_UNSCALED);
-  }
+  bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); }
+  bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); }
 
   // In the both nodes address is Op1, mask is Op2:
   // MaskedGatherSDNode  (Chain, passthru, mask, base, index, scale)

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a2fd45b48af71..66effdebc47d5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10471,24 +10471,27 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
 }
 
 // Fold sext/zext of index into index type.
-bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
-                     bool Scaled, bool Signed, SelectionDAG &DAG) {
+bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType,
+                     SelectionDAG &DAG) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   // It's always safe to look through zero extends.
   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
     SDValue Op = Index.getOperand(0);
-    MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
+      IndexType = ISD::getUnsignedIndexType(IndexType);
       Index = Op;
       return true;
+    } else if (ISD::isIndexTypeSigned(IndexType)) {
+      IndexType = ISD::getUnsignedIndexType(IndexType);
+      return true;
     }
   }
 
   // It's only safe to look through sign extends when Index is signed.
-  if (Index.getOpcode() == ISD::SIGN_EXTEND && Signed) {
+  if (Index.getOpcode() == ISD::SIGN_EXTEND &&
+      ISD::isIndexTypeSigned(IndexType)) {
     SDValue Op = Index.getOperand(0);
-    MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
       Index = Op;
       return true;
@@ -10506,6 +10509,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   SDValue Scale = MSC->getScale();
   SDValue StoreVal = MSC->getValue();
   SDValue BasePtr = MSC->getBasePtr();
+  ISD::MemIndexType IndexType = MSC->getIndexType();
   SDLoc DL(N);
 
   // Zap scatters with a zero mask.
@@ -10514,17 +10518,16 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
 
   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
-    return DAG.getMaskedScatter(
-        DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
-        MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
+                                DL, Ops, MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
   }
 
-  if (refineIndexType(MSC, Index, MSC->isIndexScaled(), MSC->isIndexSigned(),
-                      DAG)) {
+  if (refineIndexType(Index, IndexType, DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
-    return DAG.getMaskedScatter(
-        DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
-        MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
+                                DL, Ops, MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
   }
 
   return SDValue();
@@ -10602,6 +10605,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   SDValue Scale = MGT->getScale();
   SDValue PassThru = MGT->getPassThru();
   SDValue BasePtr = MGT->getBasePtr();
+  ISD::MemIndexType IndexType = MGT->getIndexType();
   SDLoc DL(N);
 
   // Zap gathers with a zero mask.
@@ -10610,19 +10614,16 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
 
   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
-    return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
-                               MGT->getMemoryVT(), DL, Ops,
-                               MGT->getMemOperand(), MGT->getIndexType(),
-                               MGT->getExtensionType());
+    return DAG.getMaskedGather(
+        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
+        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
   }
 
-  if (refineIndexType(MGT, Index, MGT->isIndexScaled(), MGT->isIndexSigned(),
-                      DAG)) {
+  if (refineIndexType(Index, IndexType, DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
-    return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
-                               MGT->getMemoryVT(), DL, Ops,
-                               MGT->getMemOperand(), MGT->getIndexType(),
-                               MGT->getExtensionType());
+    return DAG.getMaskedGather(
+        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
+        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
   }
 
   return SDValue();

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index b209aecfbfa06..e1da053aaf5b8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -8613,14 +8613,9 @@ SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op,
 ISD::MemIndexType
 TargetLowering::getCanonicalIndexType(ISD::MemIndexType IndexType, EVT MemVT,
                                       SDValue Offsets) const {
-  bool IsScaledIndex =
-      (IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::UNSIGNED_SCALED);
-  bool IsSignedIndex =
-      (IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::SIGNED_UNSCALED);
-
   // Scaling is unimportant for bytes, canonicalize to unscaled.
-  if (IsScaledIndex && MemVT.getScalarType() == MVT::i8)
-    return IsSignedIndex ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+  if (ISD::isIndexTypeScaled(IndexType) && MemVT.getScalarType() == MVT::i8)
+    return ISD::getUnscaledIndexType(IndexType);
 
   return IndexType;
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 42c44eeeb23e5..b00dd45f2bc5a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4711,10 +4711,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
     return DAG.getMergeValues({Select, Load.getValue(1)}, DL);
   }
 
-  bool IsScaled =
-      IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
-  bool IsSigned =
-      IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+  bool IsScaled = MGT->isIndexScaled();
+  bool IsSigned = MGT->isIndexSigned();
 
   // SVE supports an index scaled by sizeof(MemVT.elt) only, everything else
   // must be calculated before hand.
@@ -4727,7 +4725,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
     Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
 
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
-    IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+    IndexType = getUnscaledIndexType(IndexType);
     return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
                                MGT->getMemOperand(), IndexType, ExtType);
   }
@@ -4812,10 +4810,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   EVT MemVT = MSC->getMemoryVT();
   ISD::MemIndexType IndexType = MSC->getIndexType();
 
-  bool IsScaled =
-      IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
-  bool IsSigned =
-      IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+  bool IsScaled = MSC->isIndexScaled();
+  bool IsSigned = MSC->isIndexSigned();
 
   // SVE supports an index scaled by sizeof(MemVT.elt) only, everything else
   // must be calculated before hand.
@@ -4828,7 +4824,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
     Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
 
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
-    IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+    IndexType = getUnscaledIndexType(IndexType);
     return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
                                 MSC->getMemOperand(), IndexType,
                                 MSC->isTruncatingStore());


        


More information about the llvm-commits mailing list