[llvm] [X86] Fix miscompile in combineShiftRightArithmetic (PR #86597)

Björn Pettersson via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 01:13:40 PDT 2024


https://github.com/bjope updated https://github.com/llvm/llvm-project/pull/86597

>From ed110106acde9f0b9260bc332c881431340df0d9 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Mon, 25 Mar 2024 23:41:28 +0100
Subject: [PATCH 1/3] [X86] Pre-commit test case for bug in
 combineShiftRightArithmetic

It has been noticed that combineShiftRightArithmetic isn't dealing
properly with large shift amounts, as demonstrated by the test
case added in this commit.

I think the problem is partly related to X86 using i8 as the shift amount
type during ISel. So shift amounts larger than 127 may be treated
as negative shift amounts if we are not careful.
---
 llvm/test/CodeGen/X86/sar_fold.ll | 44 +++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/llvm/test/CodeGen/X86/sar_fold.ll b/llvm/test/CodeGen/X86/sar_fold.ll
index 21655e19440afe..7607ca386d577e 100644
--- a/llvm/test/CodeGen/X86/sar_fold.ll
+++ b/llvm/test/CodeGen/X86/sar_fold.ll
@@ -44,3 +44,47 @@ define i32 @shl24sar25(i32 %a) #0 {
   %2 = ashr exact i32 %1, 25
   ret i32 %2
 }
+
+define void @shl144sar48(ptr %p) #0 {
+; CHECK-LABEL: shl144sar48:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; CHECK-NEXT:    movswl (%eax), %ecx
+; CHECK-NEXT:    movl %ecx, %edx
+; CHECK-NEXT:    sarl $31, %edx
+; CHECK-NEXT:    shldl $2, %ecx, %edx
+; CHECK-NEXT:    shll $2, %ecx
+; CHECK-NEXT:    movl %ecx, 12(%eax)
+; CHECK-NEXT:    movl %edx, 16(%eax)
+; CHECK-NEXT:    movl $0, 8(%eax)
+; CHECK-NEXT:    movl $0, 4(%eax)
+; CHECK-NEXT:    movl $0, (%eax)
+; CHECK-NEXT:    retl
+  %a = load i160, ptr %p
+  %1 = shl i160 %a, 144
+  %2 = ashr exact i160 %1, 46
+  store i160 %2, ptr %p
+  ret void
+}
+
+; This is incorrect. The 142 least significant bits in the stored value should
+; be zero, and but 142-157 should be taken from %a with a sign-extend into the
+; two most significant bits.
+define void @shl144sar2(ptr %p) #0 {
+; CHECK-LABEL: shl144sar2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; CHECK-NEXT:    movswl (%eax), %ecx
+; CHECK-NEXT:    sarl $31, %ecx
+; CHECK-NEXT:    movl %ecx, 16(%eax)
+; CHECK-NEXT:    movl %ecx, 8(%eax)
+; CHECK-NEXT:    movl %ecx, 12(%eax)
+; CHECK-NEXT:    movl %ecx, 4(%eax)
+; CHECK-NEXT:    movl %ecx, (%eax)
+; CHECK-NEXT:    retl
+  %a = load i160, ptr %p
+  %1 = shl i160 %a, 144
+  %2 = ashr exact i160 %1, 2
+  store i160 %2, ptr %p
+  ret void
+}

>From 9002dad4be1807c9a6d56c45fccf961bdeb6dc54 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Mon, 25 Mar 2024 23:53:10 +0100
Subject: [PATCH 2/3] [X86] Fix miscompile in combineShiftRightArithmetic

When folding (ashr (shl, x, c1), c2) we need to treat c1 and c2
as unsigned to find out if the combined shift should be a left
or right shift.
Also do an early out during pre-legalization in case c1 and c2
have different types, as that would otherwise complicate the comparison
of c1 and c2 a bit.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 11 ++++++-----
 llvm/test/CodeGen/X86/sar_fold.ll       | 13 +++++--------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9acbe17d0bcad2..7c6f6fa52d5677 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47428,6 +47428,8 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   APInt SarConst = N1->getAsAPIntVal();
   EVT CVT = N1.getValueType();
 
+  if (CVT != N01.getValueType())
+    return SDValue();
   if (SarConst.isNegative())
     return SDValue();
 
@@ -47440,14 +47442,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
     SDLoc DL(N);
     SDValue NN =
         DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
-    SarConst = SarConst - (Size - ShiftSize);
-    if (SarConst == 0)
+    if (SarConst.eq(ShlConst))
       return NN;
-    if (SarConst.isNegative())
+    if (SarConst.ult(ShlConst))
       return DAG.getNode(ISD::SHL, DL, VT, NN,
-                         DAG.getConstant(-SarConst, DL, CVT));
+                         DAG.getConstant(ShlConst - SarConst, DL, CVT));
     return DAG.getNode(ISD::SRA, DL, VT, NN,
-                       DAG.getConstant(SarConst, DL, CVT));
+                       DAG.getConstant(SarConst - ShlConst, DL, CVT));
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/X86/sar_fold.ll b/llvm/test/CodeGen/X86/sar_fold.ll
index 7607ca386d577e..0f1396954b03a1 100644
--- a/llvm/test/CodeGen/X86/sar_fold.ll
+++ b/llvm/test/CodeGen/X86/sar_fold.ll
@@ -67,20 +67,17 @@ define void @shl144sar48(ptr %p) #0 {
   ret void
 }
 
-; This is incorrect. The 142 least significant bits in the stored value should
-; be zero, and but 142-157 should be taken from %a with a sign-extend into the
-; two most significant bits.
 define void @shl144sar2(ptr %p) #0 {
 ; CHECK-LABEL: shl144sar2:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; CHECK-NEXT:    movswl (%eax), %ecx
-; CHECK-NEXT:    sarl $31, %ecx
+; CHECK-NEXT:    shll $14, %ecx
 ; CHECK-NEXT:    movl %ecx, 16(%eax)
-; CHECK-NEXT:    movl %ecx, 8(%eax)
-; CHECK-NEXT:    movl %ecx, 12(%eax)
-; CHECK-NEXT:    movl %ecx, 4(%eax)
-; CHECK-NEXT:    movl %ecx, (%eax)
+; CHECK-NEXT:    movl $0, 8(%eax)
+; CHECK-NEXT:    movl $0, 12(%eax)
+; CHECK-NEXT:    movl $0, 4(%eax)
+; CHECK-NEXT:    movl $0, (%eax)
 ; CHECK-NEXT:    retl
   %a = load i160, ptr %p
   %1 = shl i160 %a, 144

>From f6ae388eb578d2e671979cba3fe92febefc6145b Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Tue, 26 Mar 2024 09:12:37 +0100
Subject: [PATCH 3/3] fixup, tidying up code comments and variable names

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 26 +++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7c6f6fa52d5677..b8d0253b26818a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47406,10 +47406,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
     return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
   }
 
-  // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
-  // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
-  // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
-  // depending on sign of (SarConst - [56,48,32,24,16])
+  // fold (SRA (SHL X, ShlConst), SraConst)
+  // into (SHL (sext_in_reg X), ShlConst - SraConst)
+  //   or (sext_in_reg X)
+  //   or (SRA (sext_in_reg X), SraConst - ShlConst)
+  // depending on relation between SraConst and ShlConst.
+  // We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
+  // us to do the sext_in_reg from corresponding bit.
 
   // sexts in X86 are MOVs. The MOVs have the same code size
   // as above SHIFTs (only SHIFT on 1 has lower code size).
@@ -47425,30 +47428,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   SDValue N00 = N0.getOperand(0);
   SDValue N01 = N0.getOperand(1);
   APInt ShlConst = N01->getAsAPIntVal();
-  APInt SarConst = N1->getAsAPIntVal();
+  APInt SraConst = N1->getAsAPIntVal();
   EVT CVT = N1.getValueType();
 
   if (CVT != N01.getValueType())
     return SDValue();
-  if (SarConst.isNegative())
+  if (SraConst.isNegative())
     return SDValue();
 
   for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
     unsigned ShiftSize = SVT.getSizeInBits();
-    // skipping types without corresponding sext/zext and
-    // ShlConst that is not one of [56,48,32,24,16]
+    // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
     if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
       continue;
     SDLoc DL(N);
     SDValue NN =
         DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
-    if (SarConst.eq(ShlConst))
+    if (SraConst.eq(ShlConst))
       return NN;
-    if (SarConst.ult(ShlConst))
+    if (SraConst.ult(ShlConst))
       return DAG.getNode(ISD::SHL, DL, VT, NN,
-                         DAG.getConstant(ShlConst - SarConst, DL, CVT));
+                         DAG.getConstant(ShlConst - SraConst, DL, CVT));
     return DAG.getNode(ISD::SRA, DL, VT, NN,
-                       DAG.getConstant(SarConst - ShlConst, DL, CVT));
+                       DAG.getConstant(SraConst - ShlConst, DL, CVT));
   }
   return SDValue();
 }



More information about the llvm-commits mailing list