[llvm] [DAGCombiner] In mergeTruncStore, make sure we aren't storing shifted in bits. (PR #90939)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 22:20:16 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/90939

When looking through a right shift, we need to make sure that all of
the bits we use from the shift result come from the shift input and
not from the sign or zero bits that are shifted in.
    
Fixes #90936.

>From 556e3009bb6044d983c0ec9aa4e594cfc6676e26 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 2 May 2024 12:13:42 -0700
Subject: [PATCH 1/3] [RISCV] Add partial validation of Z extension name to
 RISCVISAInfo::parseNormalizedArchString

If 'z' is given as the complete extension name, or is followed directly
by a digit, the extension map's compare function will crash. Check for
these cases and return an error instead.
---
 llvm/lib/Support/RISCVISAUtils.cpp               |  3 ++-
 llvm/lib/TargetParser/RISCVISAInfo.cpp           |  5 +++++
 llvm/unittests/TargetParser/RISCVISAInfoTest.cpp | 10 +++++++++-
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Support/RISCVISAUtils.cpp b/llvm/lib/Support/RISCVISAUtils.cpp
index ca7518f71907b5..46efe93695074f 100644
--- a/llvm/lib/Support/RISCVISAUtils.cpp
+++ b/llvm/lib/Support/RISCVISAUtils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/RISCVISAUtils.h"
+#include "llvm/ADT/StringExtras.h"
 #include <cassert>
 
 using namespace llvm;
@@ -35,7 +36,7 @@ enum RankFlags {
 // Get the rank for single-letter extension, lower value meaning higher
 // priority.
 static unsigned singleLetterExtensionRank(char Ext) {
-  assert(Ext >= 'a' && Ext <= 'z');
+  assert(isLower(Ext));
   switch (Ext) {
   case 'i':
     return 0;
diff --git a/llvm/lib/TargetParser/RISCVISAInfo.cpp b/llvm/lib/TargetParser/RISCVISAInfo.cpp
index c1d50afee09b08..d244326537faff 100644
--- a/llvm/lib/TargetParser/RISCVISAInfo.cpp
+++ b/llvm/lib/TargetParser/RISCVISAInfo.cpp
@@ -485,6 +485,11 @@ RISCVISAInfo::parseNormalizedArchString(StringRef Arch) {
     if (MajorVersionStr.getAsInteger(10, MajorVersion))
       return createStringError(errc::invalid_argument,
                                "failed to parse major version number");
+
+    if (ExtName[0] == 'z' && (ExtName.size() == 1 || isDigit(ExtName[1])))
+      return createStringError(errc::invalid_argument,
+                               "'z' must be followed by a letter");
+
     ISAInfo->addExtension(ExtName, {MajorVersion, MinorVersion});
   }
   ISAInfo->updateImpliedLengths();
diff --git a/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp b/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp
index ec886bad4f67f7..eb8eab73686931 100644
--- a/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp
+++ b/llvm/unittests/TargetParser/RISCVISAInfoTest.cpp
@@ -46,7 +46,7 @@ TEST(ParseNormalizedArchString, RejectsMalformedInputs) {
   }
 }
 
-TEST(ParseNormalizedArchString, OnlyVersion) {
+TEST(ParseNormalizedArchString, RejectsOnlyVersion) {
   for (StringRef Input : {"rv64i2p0_1p0", "rv32i2p0_1p0"}) {
     EXPECT_EQ(
         toString(RISCVISAInfo::parseNormalizedArchString(Input).takeError()),
@@ -54,6 +54,14 @@ TEST(ParseNormalizedArchString, OnlyVersion) {
   }
 }
 
+TEST(ParseNormalizedArchString, RejectsBadZ) {
+  for (StringRef Input : {"rv64i2p0_z1p0", "rv32i2p0_z2a1p0"}) {
+    EXPECT_EQ(
+        toString(RISCVISAInfo::parseNormalizedArchString(Input).takeError()),
+        "'z' must be followed by a letter");
+  }
+}
+
 TEST(ParseNormalizedArchString, AcceptsValidBaseISAsAndSetsXLen) {
   auto MaybeRV32I = RISCVISAInfo::parseNormalizedArchString("rv32i2p0");
   ASSERT_THAT_EXPECTED(MaybeRV32I, Succeeded());

>From e77c987e76a6d349cf17b737fda9dc27c742a310 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 2 May 2024 22:09:59 -0700
Subject: [PATCH 2/3] [AArch64] Add test for #90936. NFC

---
 llvm/test/CodeGen/AArch64/pr90936.ll | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/pr90936.ll

diff --git a/llvm/test/CodeGen/AArch64/pr90936.ll b/llvm/test/CodeGen/AArch64/pr90936.ll
new file mode 100644
index 00000000000000..969daaf11b567d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/pr90936.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=aarch64 | FileCheck %s
+
+define void @f(i16 %0, ptr %1) {
+; CHECK-LABEL: f:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    strh w0, [x1]
+; CHECK-NEXT:    ret
+  %3 = trunc i16 %0 to i8
+  %4 = trunc i16 %0 to i14
+  %new0 = lshr i14 %4, 8
+ store i8 %3, ptr %1, align 1
+  %5 = getelementptr i8, ptr %1, i64 1
+ %6 = trunc i14 %new0 to i8
+  store i8 %6, ptr %5, align 1
+  ret void
+}

>From 1d83bd91c50ac0ebb0e713c9e714b1c9a59876a6 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 2 May 2024 22:16:50 -0700
Subject: [PATCH 3/3] [DAGCombiner] In mergeTruncStore, make sure we aren't
 storing shifted in bits.

When looking through a right shift, we need to make sure that all of
the bits we use from the shift result come from the shift input and
not from the sign or zero bits that are shifted in.

Fixes #90936.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 4 ++++
 llvm/test/CodeGen/AArch64/pr90936.ll          | 4 +++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c0bbea16a64262..fc6bbc119d3c1b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8840,6 +8840,10 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
       if (ShiftAmtC % NarrowNumBits != 0)
         return SDValue();
 
+      // Make sure we aren't reading bits that are shifted in.
+      if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
+        return SDValue();
+
       Offset = ShiftAmtC / NarrowNumBits;
       WideVal = WideVal.getOperand(0);
     }
diff --git a/llvm/test/CodeGen/AArch64/pr90936.ll b/llvm/test/CodeGen/AArch64/pr90936.ll
index 969daaf11b567d..3dd2b18f9bf1c7 100644
--- a/llvm/test/CodeGen/AArch64/pr90936.ll
+++ b/llvm/test/CodeGen/AArch64/pr90936.ll
@@ -4,7 +4,9 @@
 define void @f(i16 %0, ptr %1) {
 ; CHECK-LABEL: f:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    strh w0, [x1]
+; CHECK-NEXT:    ubfx w8, w0, #8, #6
+; CHECK-NEXT:    strb w0, [x1]
+; CHECK-NEXT:    strb w8, [x1, #1]
 ; CHECK-NEXT:    ret
   %3 = trunc i16 %0 to i8
   %4 = trunc i16 %0 to i14



More information about the llvm-commits mailing list