[llvm] [DAG] Improve known bits of Zext/Sext loads with range metadata (PR #80829)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 20 09:20:20 PST 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/80829

>From 5b145596f60fbea6d60388c9d30f97f57ea9043f Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 20 Feb 2024 17:20:02 +0000
Subject: [PATCH] [DAG] Improve known bits of Zext/Sext loads with range
 metadata

This extends computeKnownBits to extending loads that carry range metadata.
The range metadata is evaluated on the original memory type, and the
resulting known bits are then zero-, sign- or any-extended to the result's
BitWidth.
---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 48 +++++++------
 llvm/test/CodeGen/AArch64/setcc_knownbits.ll  |  4 +-
 .../CodeGen/AArch64SelectionDAGTest.cpp       | 68 +++++++++++++++++++
 3 files changed, 98 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index add92cf8b31e44..a3e2b8518a46d8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3645,32 +3645,42 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
           }
         }
       }
-    } else if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
-      // If this is a ZEXTLoad and we are looking at the loaded value.
-      EVT VT = LD->getMemoryVT();
-      unsigned MemBits = VT.getScalarSizeInBits();
-      Known.Zero.setBitsFrom(MemBits);
-    } else if (const MDNode *Ranges = LD->getRanges()) {
-      EVT VT = LD->getValueType(0);
-
-      // TODO: Handle for extending loads
-      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
+    } else if (Op.getResNo() == 0) {
+      KnownBits Known0(!LD->getMemoryVT().isScalableVT()
+                           ? LD->getMemoryVT().getSizeInBits()
+                           : BitWidth);
+      EVT VT = Op.getValueType();
+      // Fill in any known bits from range information. There are 3 types being
+      // used. The result's VT (same vector elt size as BitWidth), the loaded
+      // MemoryVT (which may or may not be vector) and the range VT's original
+      // type. The range metadata needs the full range (i.e.
+      // MemoryVT().getSizeInBits()), which is truncated to the correct elt size
+      // if it is known. These are then extended to the original VT sizes below.
+      if (const MDNode *MD = LD->getRanges()) {
+        computeKnownBitsFromRangeMetadata(*MD, Known0);
         if (VT.isVector()) {
           // Handle truncation to the first demanded element.
           // TODO: Figure out which demanded elements are covered
           if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
             break;
+          Known0 = Known0.trunc(BitWidth);
+        }
+      }
 
-          // Handle the case where a load has a vector type, but scalar memory
-          // with an attached range.
-          EVT MemVT = LD->getMemoryVT();
-          KnownBits KnownFull(MemVT.getSizeInBits());
+      if (LD->getMemoryVT().isVector())
+        Known0 = Known0.trunc(LD->getMemoryVT().getScalarSizeInBits());
 
-          computeKnownBitsFromRangeMetadata(*Ranges, KnownFull);
-          Known = KnownFull.trunc(BitWidth);
-        } else
-          computeKnownBitsFromRangeMetadata(*Ranges, Known);
-      }
+      // Extend the Known bits from memory to the size of the result.
+      if (ISD::isZEXTLoad(Op.getNode()))
+        Known = Known0.zext(BitWidth);
+      else if (ISD::isSEXTLoad(Op.getNode()))
+        Known = Known0.sext(BitWidth);
+      else if (ISD::isEXTLoad(Op.getNode()))
+        Known = Known0.anyext(BitWidth);
+      else
+        Known = Known0;
+      assert(Known.getBitWidth() == BitWidth);
+      return Known;
     }
     break;
   }
diff --git a/llvm/test/CodeGen/AArch64/setcc_knownbits.ll b/llvm/test/CodeGen/AArch64/setcc_knownbits.ll
index 46b714d8e5fbbe..bb9546af8bb7b6 100644
--- a/llvm/test/CodeGen/AArch64/setcc_knownbits.ll
+++ b/llvm/test/CodeGen/AArch64/setcc_knownbits.ll
@@ -21,9 +21,7 @@ define noundef i1 @logger(i32 noundef %logLevel, ptr %ea, ptr %pll) {
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB1_2: // %land.rhs
 ; CHECK-NEXT:    ldr x8, [x1]
-; CHECK-NEXT:    ldrb w8, [x8]
-; CHECK-NEXT:    cmp w8, #0
-; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ldrb w0, [x8]
 ; CHECK-NEXT:    ret
 entry:
   %0 = load i32, ptr %pll, align 4
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index bb8e76a2eeb8be..e0772684e3a954 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -6,11 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/SourceMgr.h"
@@ -728,4 +730,70 @@ TEST_F(AArch64SelectionDAGTest, ReplaceAllUsesWith) {
   EXPECT_EQ(DAG->getPCSections(New.getNode()), MD);
 }
 
+TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_known01) {
+  SDLoc Loc; // i8 loads with !range [0, 2), extended to i32 three ways.
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Int64VT = EVT::getIntegerVT(Context, 64);
+  auto Ptr = DAG->getConstant(0, Loc, Int64VT);
+  auto PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG->getMachineFunction(), 0);
+  AAMDNodes AA;
+  MDBuilder MDHelper(*DAG->getContext());
+  MDNode *Range = MDHelper.createRange(APInt(8, 0), APInt(8, 2));
+  MachineMemOperand *MMO = DAG->getMachineFunction().getMachineMemOperand(
+      PtrInfo, MachineMemOperand::MOLoad, /*Size=*/1, Align(8), AA, Range);
+
+  auto ALoad = DAG->getExtLoad(ISD::EXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  KnownBits Known = DAG->computeKnownBits(ALoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xfe)); // anyext: bits 8-31 unknown.
+  EXPECT_EQ(Known.One, APInt(32, 0));
+
+  auto ZLoad = DAG->getExtLoad(ISD::ZEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(ZLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xfffffffe)); // zext: only bit 0 unknown.
+  EXPECT_EQ(Known.One, APInt(32, 0));
+
+  auto SLoad = DAG->getExtLoad(ISD::SEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(SLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xfffffffe)); // sext of a value >= 0.
+  EXPECT_EQ(Known.One, APInt(32, 0));
+}
+
+TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_knownnegative) {
+  SDLoc Loc; // i8 loads with !range [0xf0, 0xff): bits 4-7 are always set.
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Int64VT = EVT::getIntegerVT(Context, 64);
+  auto Ptr = DAG->getConstant(0, Loc, Int64VT);
+  auto PtrInfo =
+      MachinePointerInfo::getFixedStack(DAG->getMachineFunction(), 0);
+  AAMDNodes AA;
+  MDBuilder MDHelper(*DAG->getContext());
+  MDNode *Range = MDHelper.createRange(APInt(8, 0xf0), APInt(8, 0xff));
+  MachineMemOperand *MMO = DAG->getMachineFunction().getMachineMemOperand(
+      PtrInfo, MachineMemOperand::MOLoad, /*Size=*/1, Align(8), AA, Range);
+
+  auto ALoad = DAG->getExtLoad(ISD::EXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  KnownBits Known = DAG->computeKnownBits(ALoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0));
+  EXPECT_EQ(Known.One, APInt(32, 0xf0)); // anyext: bits 8-31 unknown.
+
+  auto ZLoad = DAG->getExtLoad(ISD::ZEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(ZLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0xffffff00)); // zext: bits 8-31 zero.
+  EXPECT_EQ(Known.One, APInt(32, 0x000000f0));
+
+  auto SLoad = DAG->getExtLoad(ISD::SEXTLOAD, Loc, Int32VT, DAG->getEntryNode(),
+                               Ptr, Int8VT, MMO);
+  Known = DAG->computeKnownBits(SLoad);
+  EXPECT_EQ(Known.Zero, APInt(32, 0));
+  EXPECT_EQ(Known.One, APInt(32, 0xfffffff0)); // sext copies set bit 7.
+}
+
 } // end namespace llvm



More information about the llvm-commits mailing list