[llvm] [X86] Fix miscompile in combineShiftRightArithmetic (PR #86597)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 16:02:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Björn Pettersson (bjope)

<details>
<summary>Changes</summary>

When folding (ashr (shl x, c1), c2) we need to treat c1 and c2
as unsigned to find out if the combined shift should be a left
or right shift.
Also exit early during pre-legalization in case c1 and c2
have different types, as that would otherwise complicate the
comparison of c1 and c2.

---
Full diff: https://github.com/llvm/llvm-project/pull/86597.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+6-5) 
- (modified) llvm/test/CodeGen/X86/sar_fold.ll (+44) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9acbe17d0bcad2..7c6f6fa52d5677 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47428,6 +47428,8 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   APInt SarConst = N1->getAsAPIntVal();
   EVT CVT = N1.getValueType();
 
+  if (CVT != N01.getValueType())
+    return SDValue();
   if (SarConst.isNegative())
     return SDValue();
 
@@ -47440,14 +47442,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
     SDLoc DL(N);
     SDValue NN =
         DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
-    SarConst = SarConst - (Size - ShiftSize);
-    if (SarConst == 0)
+    if (SarConst.eq(ShlConst))
       return NN;
-    if (SarConst.isNegative())
+    if (SarConst.ult(ShlConst))
       return DAG.getNode(ISD::SHL, DL, VT, NN,
-                         DAG.getConstant(-SarConst, DL, CVT));
+                         DAG.getConstant(ShlConst - SarConst, DL, CVT));
     return DAG.getNode(ISD::SRA, DL, VT, NN,
-                       DAG.getConstant(SarConst, DL, CVT));
+                       DAG.getConstant(SarConst - ShlConst, DL, CVT));
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/X86/sar_fold.ll b/llvm/test/CodeGen/X86/sar_fold.ll
index 21655e19440afe..93810b3e717650 100644
--- a/llvm/test/CodeGen/X86/sar_fold.ll
+++ b/llvm/test/CodeGen/X86/sar_fold.ll
@@ -44,3 +44,47 @@ define i32 @shl24sar25(i32 %a) #0 {
   %2 = ashr exact i32 %1, 25
   ret i32 %2
 }
+
+define void @shl144sar48(ptr %p) #0 {
+; CHECK-LABEL: shl144sar48:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; CHECK-NEXT:    movswl (%eax), %ecx
+; CHECK-NEXT:    movl %ecx, %edx
+; CHECK-NEXT:    sarl $31, %edx
+; CHECK-NEXT:    shldl $2, %ecx, %edx
+; CHECK-NEXT:    shll $2, %ecx
+; CHECK-NEXT:    movl %ecx, 12(%eax)
+; CHECK-NEXT:    movl %edx, 16(%eax)
+; CHECK-NEXT:    movl $0, 8(%eax)
+; CHECK-NEXT:    movl $0, 4(%eax)
+; CHECK-NEXT:    movl $0, (%eax)
+; CHECK-NEXT:    retl
+  %a = load i160, ptr %p
+  %1 = shl i160 %a, 144
+  %2 = ashr exact i160 %1, 46
+  store i160 %2, ptr %p
+  ret void
+}
+
+; The 142 least significant bits in the stored value should be zero, and bits
+; 142-157 should be taken from %a, with a sign-extension into the two most
+; significant bits (158-159).
+define void @shl144sar2(ptr %p) #0 {
+; CHECK-LABEL: shl144sar2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; CHECK-NEXT:    movswl (%eax), %ecx
+; CHECK-NEXT:    shll $14, %ecx
+; CHECK-NEXT:    movl %ecx, 16(%eax)
+; CHECK-NEXT:    movl $0, 8(%eax)
+; CHECK-NEXT:    movl $0, 12(%eax)
+; CHECK-NEXT:    movl $0, 4(%eax)
+; CHECK-NEXT:    movl $0, (%eax)
+; CHECK-NEXT:    retl
+  %a = load i160, ptr %p
+  %1 = shl i160 %a, 144
+  %2 = ashr exact i160 %1, 2
+  store i160 %2, ptr %p
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/86597


More information about the llvm-commits mailing list