[llvm] [DAG] Improve known bits of Zext/Sext loads with range metadata (PR #80829)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 05:05:17 PST 2024


================
@@ -3612,32 +3612,42 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
           }
         }
       }
-    } else if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
-      // If this is a ZEXTLoad and we are looking at the loaded value.
-      EVT VT = LD->getMemoryVT();
-      unsigned MemBits = VT.getScalarSizeInBits();
-      Known.Zero.setBitsFrom(MemBits);
-    } else if (const MDNode *Ranges = LD->getRanges()) {
-      EVT VT = LD->getValueType(0);
-
-      // TODO: Handle for extending loads
-      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
+    } else if (Op.getResNo() == 0) {
+      KnownBits Known0(!LD->getMemoryVT().isScalableVT()
+                           ? LD->getMemoryVT().getSizeInBits()
+                           : BitWidth);
+      EVT VT = Op.getValueType();
+      // Fill in any known bits from range information. There are 3 types being
+      // used: the result's VT (same vector elt size as BitWidth), the loaded
+      // MemoryVT (which may or may not be vector) and the range VT's original
+      // type. The range metadata needs the full range (i.e.
+      // MemoryVT().getSizeInBits()), which is truncated to the correct elt size
+      // if it is known. These are then extended to the original VT sizes below.
+      if (const MDNode *MD = LD->getRanges()) {
+        computeKnownBitsFromRangeMetadata(*MD, Known0);
         if (VT.isVector()) {
           // Handle truncation to the first demanded element.
           // TODO: Figure out which demanded elements are covered
           if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
             break;
+          Known0 = Known0.trunc(BitWidth);
+        }
+      }
 
-          // Handle the case where a load has a vector type, but scalar memory
-          // with an attached range.
-          EVT MemVT = LD->getMemoryVT();
-          KnownBits KnownFull(MemVT.getSizeInBits());
+      if (LD->getMemoryVT().isVector())
+        Known0 = Known0.trunc(LD->getMemoryVT().getScalarSizeInBits());
 
-          computeKnownBitsFromRangeMetadata(*Ranges, KnownFull);
-          Known = KnownFull.trunc(BitWidth);
-        } else
-          computeKnownBitsFromRangeMetadata(*Ranges, Known);
-      }
+      // Extend the Known bits from memory to the size of the result.
+      if (ISD::isZEXTLoad(Op.getNode()))
+        Known = Known0.zext(BitWidth);
+      else if (ISD::isSEXTLoad(Op.getNode()))
+        Known = Known0.sext(BitWidth);
+      else if (ISD::isEXTLoad(Op.getNode()))
+        Known = Known0.anyext(BitWidth);
----------------
arsenm wrote:

Do you need to guard against FP type ext loads? 

https://github.com/llvm/llvm-project/pull/80829


More information about the llvm-commits mailing list