[llvm] [DAGCombiner] In mergeTruncStore, make sure we aren't storing shifted in bits. (PR #90939)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 22:20:44 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

When looking through a right shift, we need to make sure that all of
the bits we are using from the shift come from the shift input and
not from the sign or zero bits that are shifted in.
    
Fixes #<!-- -->90936.

---
Full diff: https://github.com/llvm/llvm-project/pull/90939.diff


5 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4) 
- (modified) llvm/lib/Support/RISCVISAUtils.cpp (+2-1) 
- (modified) llvm/lib/TargetParser/RISCVISAInfo.cpp (+5) 
- (added) llvm/test/CodeGen/AArch64/pr90936.ll (+19) 
- (modified) llvm/unittests/TargetParser/RISCVISAInfoTest.cpp (+9-1) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c0bbea16a64262..fc6bbc119d3c1b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8840,6 +8840,10 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
       if (ShiftAmtC % NarrowNumBits != 0)
         return SDValue();
 
+      // Make sure we aren't reading bits that are shifted in.
+      if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
+        return SDValue();
+
       Offset = ShiftAmtC / NarrowNumBits;
       WideVal = WideVal.getOperand(0);
     }
diff --git a/llvm/lib/Support/RISCVISAUtils.cpp b/llvm/lib/Support/RISCVISAUtils.cpp
index ca7518f71907b5..46efe93695074f 100644
--- a/llvm/lib/Support/RISCVISAUtils.cpp
+++ b/llvm/lib/Support/RISCVISAUtils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/RISCVISAUtils.h"
+#include "llvm/ADT/StringExtras.h"
 #include <cassert>
 
 using namespace llvm;
@@ -35,7 +36,7 @@ enum RankFlags {
 // Get the rank for single-letter extension, lower value meaning higher
 // priority.
 static unsigned singleLetterExtensionRank(char Ext) {
-  assert(Ext >= 'a' && Ext <= 'z');
+  assert(isLower(Ext));
   switch (Ext) {
   case 'i':
     return 0;
diff --git a/llvm/lib/TargetParser/RISCVISAInfo.cpp b/llvm/lib/TargetParser/RISCVISAInfo.cpp
index c1d50afee09b08..d244326537faff 100644
--- a/llvm/lib/TargetParser/RISCVISAInfo.cpp
+++ b/llvm/lib/TargetParser/RISCVISAInfo.cpp
@@ -485,6 +485,11 @@ RISCVISAInfo::parseNormalizedArchString(StringRef Arch) {
     if (MajorVersionStr.getAsInteger(10, MajorVersion))
       return createStringError(errc::invalid_argument,
                                "failed to parse major version number");
+
+    if (ExtName[0] == 'z' && (ExtName.size() == 1 || isDigit(ExtName[1])))
+      return createStringError(errc::invalid_argument,
+                               "'z' must be followed by a letter");
+
     ISAInfo->addExtension(ExtName, {MajorVersion, MinorVersion});
   }
   ISAInfo->updateImpliedLengths();
diff --git a/llvm/test/CodeGen/AArch64/pr90936.ll b/llvm/test/CodeGen/AArch64/pr90936.ll
new file mode 100644
index 00000000000000..3dd2b18f9bf1c7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/pr90936.ll
@@ -0,0 +1,19 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=aarch64 | FileCheck %s
+
+define void @f(i16 %0, ptr %1) {
+; CHECK-LABEL: f:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ubfx w8, w0, #8, #6
+; CHECK-NEXT:    strb w0, [x1]
+; CHECK-NEXT:    strb w8, [x1, #1]
+; CHECK-NEXT:    ret
+  %3 = trunc i16 %0 to i8
+  %4 = trunc i16 %0 to i14
+  %new0 = lshr i14 %4, 8
+ store i8 %3, ptr %1, align 1
+  %5 = getelementptr i8, ptr %1, i64 1
+ %6 = trunc i14 %new0 to i8
+  store i8 %6, ptr %5, align 1
+  ret void
+}
diff --git a/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp b/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp
index ec886bad4f67f7..eb8eab73686931 100644
--- a/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp
+++ b/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp
@@ -46,7 +46,7 @@ TEST(ParseNormalizedArchString, RejectsMalformedInputs) {
   }
 }
 
-TEST(ParseNormalizedArchString, OnlyVersion) {
+TEST(ParseNormalizedArchString, RejectsOnlyVersion) {
   for (StringRef Input : {"rv64i2p0_1p0", "rv32i2p0_1p0"}) {
     EXPECT_EQ(
         toString(RISCVISAInfo::parseNormalizedArchString(Input).takeError()),
@@ -54,6 +54,14 @@ TEST(ParseNormalizedArchString, OnlyVersion) {
   }
 }
 
+TEST(ParseNormalizedArchString, RejectsBadZ) {
+  for (StringRef Input : {"rv64i2p0_z1p0", "rv32i2p0_z2a1p0"}) {
+    EXPECT_EQ(
+        toString(RISCVISAInfo::parseNormalizedArchString(Input).takeError()),
+        "'z' must be followed by a letter");
+  }
+}
+
 TEST(ParseNormalizedArchString, AcceptsValidBaseISAsAndSetsXLen) {
   auto MaybeRV32I = RISCVISAInfo::parseNormalizedArchString("rv32i2p0");
   ASSERT_THAT_EXPECTED(MaybeRV32I, Succeeded());

``````````

</details>


https://github.com/llvm/llvm-project/pull/90939


More information about the llvm-commits mailing list