[llvm] dbca8a4 - [DAG] Improve known bits of Zext/Sext loads with range metadata (#80829)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 04:53:17 PST 2024


Author: David Green
Date: 2024-02-29T12:53:13Z
New Revision: dbca8a49b6dbbb79913d6a1bc1d59f4947353e96

URL: https://github.com/llvm/llvm-project/commit/dbca8a49b6dbbb79913d6a1bc1d59f4947353e96
DIFF: https://github.com/llvm/llvm-project/commit/dbca8a49b6dbbb79913d6a1bc1d59f4947353e96.diff

LOG: [DAG] Improve known bits of Zext/Sext loads with range metadata (#80829)

This extends the known bits for extending loads which have range
metadata, handling the range metadata on the original memory type,
extending that to the correct BitWidth.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/test/CodeGen/AArch64/setcc_knownbits.ll
    llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e3e7de1573c818..5b1b7c7c627723 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3645,32 +3645,42 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
           }
         }
       }
-    } else if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
-      // If this is a ZEXTLoad and we are looking at the loaded value.
-      EVT VT = LD->getMemoryVT();
-      unsigned MemBits = VT.getScalarSizeInBits();
-      Known.Zero.setBitsFrom(MemBits);
-    } else if (const MDNode *Ranges = LD->getRanges()) {
-      EVT VT = LD->getValueType(0);
-
-      // TODO: Handle for extending loads
-      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
+    } else if (Op.getResNo() == 0) {
+      KnownBits Known0(!LD->getMemoryVT().isScalableVT()
+                           ? LD->getMemoryVT().getFixedSizeInBits()
+                           : BitWidth);
+      EVT VT = Op.getValueType();
+      // Fill in any known bits from range information. Three types are used:
+      // the result's VT (same vector elt size as BitWidth), the loaded
+      // MemoryVT (which may or may not be a vector) and the range's original
+      // type. The range metadata needs the full range (i.e.
+      // MemoryVT().getSizeInBits()), which is truncated to the correct elt size
+      // if it is known. This is then extended to BitWidth below.
+      if (const MDNode *MD = LD->getRanges()) {
+        computeKnownBitsFromRangeMetadata(*MD, Known0);
         if (VT.isVector()) {
           // Handle truncation to the first demanded element.
           // TODO: Figure out which demanded elements are covered
           if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
             break;
+          Known0 = Known0.trunc(BitWidth);
+        }
+      }
 
-          // Handle the case where a load has a vector type, but scalar memory
-          // with an attached range.
-          EVT MemVT = LD->getMemoryVT();
-          KnownBits KnownFull(MemVT.getSizeInBits());
+      if (LD->getMemoryVT().isVector())
+        Known0 = Known0.trunc(LD->getMemoryVT().getScalarSizeInBits());
 
-          computeKnownBitsFromRangeMetadata(*Ranges, KnownFull);
-          Known = KnownFull.trunc(BitWidth);
-        } else
-          computeKnownBitsFromRangeMetadata(*Ranges, Known);
-      }
+      // Extend the Known bits from memory to the size of the result.
+      if (ISD::isZEXTLoad(Op.getNode()))
+        Known = Known0.zext(BitWidth);
+      else if (ISD::isSEXTLoad(Op.getNode()))
+        Known = Known0.sext(BitWidth);
+      else if (ISD::isEXTLoad(Op.getNode()))
+        Known = Known0.anyext(BitWidth);
+      else
+        Known = Known0;
+      assert(Known.getBitWidth() == BitWidth);
+      return Known;
     }
     break;
   }

diff  --git a/llvm/test/CodeGen/AArch64/setcc_knownbits.ll b/llvm/test/CodeGen/AArch64/setcc_knownbits.ll
index 46b714d8e5fbbe..bb9546af8bb7b6 100644
--- a/llvm/test/CodeGen/AArch64/setcc_knownbits.ll
+++ b/llvm/test/CodeGen/AArch64/setcc_knownbits.ll
@@ -21,9 +21,7 @@ define noundef i1 @logger(i32 noundef %logLevel, ptr %ea, ptr %pll) {
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB1_2: // %land.rhs
 ; CHECK-NEXT:    ldr x8, [x1]
-; CHECK-NEXT:    ldrb w8, [x8]
-; CHECK-NEXT:    cmp w8, #0
-; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ldrb w0, [x8]
 ; CHECK-NEXT:    ret
 entry:
   %0 = load i32, ptr %pll, align 4

diff  --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index bb8e76a2eeb8be..e0772684e3a954 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -6,11 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/SourceMgr.h"
@@ -728,4 +730,70 @@ TEST_F(AArch64SelectionDAGTest, ReplaceAllUsesWith) {
   EXPECT_EQ(DAG->getPCSections(New.getNode()), MD);
 }
 
+TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_known01) {
+  SDLoc Loc;
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Int64VT = EVT::getIntegerVT(Context, 64);
+  auto Ptr = DAG->getConstant(0, Loc, Int64VT);
+  auto PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG->getMachineFunction(), 0);
+  AAMDNodes AA;
+  MDBuilder MDHelper(*DAG->getContext());
+  MDNode *Range = MDHelper.createRange(APInt(8, 0), APInt(8, 2));
+  MachineMemOperand *MMO = DAG->getMachineFunction().getMachineMemOperand(
+      PtrInfo, MachineMemOperand::MOLoad, 8, Align(8), AA, Range);
+  // i8 range [0, 2): bits 1-7 of the loaded byte are known zero.
+  auto ALoad = DAG->getExtLoad(ISD::EXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  KnownBits Known = DAG->computeKnownBits(ALoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xfe)); // EXTLOAD: bits 8-31 unknown.
+  EXPECT_EQ(Known.One, APInt(32, 0));
+
+  auto ZLoad = DAG->getExtLoad(ISD::ZEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(ZLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xfffffffe)); // Zero-extended.
+  EXPECT_EQ(Known.One, APInt(32, 0));
+
+  auto SLoad = DAG->getExtLoad(ISD::SEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(SLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xfffffffe)); // Sign bit 7 is known 0.
+  EXPECT_EQ(Known.One, APInt(32, 0));
+}
+
+TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_knownnegative) {
+  SDLoc Loc;
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Int64VT = EVT::getIntegerVT(Context, 64);
+  auto Ptr = DAG->getConstant(0, Loc, Int64VT);
+  auto PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG->getMachineFunction(), 0);
+  AAMDNodes AA;
+  MDBuilder MDHelper(*DAG->getContext());
+  MDNode *Range = MDHelper.createRange(APInt(8, 0xf0), APInt(8, 0xff));
+  MachineMemOperand *MMO = DAG->getMachineFunction().getMachineMemOperand(
+      PtrInfo, MachineMemOperand::MOLoad, 8, Align(8), AA, Range);
+  // i8 range [0xf0, 0xff): bits 4-7 of the loaded byte are known one.
+  auto ALoad = DAG->getExtLoad(ISD::EXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  KnownBits Known = DAG->computeKnownBits(ALoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0));
+  EXPECT_EQ(Known.One, APInt(32, 0xf0)); // EXTLOAD: bits 8-31 unknown.
+
+  auto ZLoad = DAG->getExtLoad(ISD::ZEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(ZLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xffffff00)); // Zero-extended.
+  EXPECT_EQ(Known.One, APInt(32, 0x000000f0));
+
+  auto SLoad = DAG->getExtLoad(ISD::SEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(SLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0));
+  EXPECT_EQ(Known.One, APInt(32, 0xfffffff0)); // Sign bit 7 is known 1.
+}
+
 } // end namespace llvm


        


More information about the llvm-commits mailing list