[llvm] 3e6e54e - [X86] Fix miscompile in combineShiftRightArithmetic (#86597)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 12:53:37 PDT 2024


Author: Björn Pettersson
Date: 2024-03-26T20:53:34+01:00
New Revision: 3e6e54eb795ce7a1ccd47df8c22fc08125a88886

URL: https://github.com/llvm/llvm-project/commit/3e6e54eb795ce7a1ccd47df8c22fc08125a88886
DIFF: https://github.com/llvm/llvm-project/commit/3e6e54eb795ce7a1ccd47df8c22fc08125a88886.diff

LOG: [X86] Fix miscompile in combineShiftRightArithmetic (#86597)

When folding (ashr (shl, x, c1), c2) we need to treat c1 and c2
as unsigned to find out if the combined shift should be a left
or right shift.
Also do an early out during pre-legalization in case c1 and c2
have different types, as that would otherwise complicate the
comparison of c1 and c2 a bit.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/sar_fold.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 1ce742a7c19a74..9bad38f97c6a29 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47406,10 +47406,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
     return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
   }
 
-  // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
-  // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
-  // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
-  // depending on sign of (SarConst - [56,48,32,24,16])
+  // fold (SRA (SHL X, ShlConst), SraConst)
+  // into (SHL (sext_in_reg X), ShlConst - SraConst)
+  //   or (sext_in_reg X)
+  //   or (SRA (sext_in_reg X), SraConst - ShlConst)
+  // depending on relation between SraConst and ShlConst.
+  // We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
+  // us to do the sext_in_reg from corresponding bit.
 
   // sexts in X86 are MOVs. The MOVs have the same code size
   // as above SHIFTs (only SHIFT on 1 has lower code size).
@@ -47425,29 +47428,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   SDValue N00 = N0.getOperand(0);
   SDValue N01 = N0.getOperand(1);
   APInt ShlConst = N01->getAsAPIntVal();
-  APInt SarConst = N1->getAsAPIntVal();
+  APInt SraConst = N1->getAsAPIntVal();
   EVT CVT = N1.getValueType();
 
-  if (SarConst.isNegative())
+  if (CVT != N01.getValueType())
+    return SDValue();
+  if (SraConst.isNegative())
     return SDValue();
 
   for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
     unsigned ShiftSize = SVT.getSizeInBits();
-    // skipping types without corresponding sext/zext and
-    // ShlConst that is not one of [56,48,32,24,16]
+    // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
     if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
       continue;
     SDLoc DL(N);
     SDValue NN =
         DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
-    SarConst = SarConst - (Size - ShiftSize);
-    if (SarConst == 0)
+    if (SraConst.eq(ShlConst))
       return NN;
-    if (SarConst.isNegative())
+    if (SraConst.ult(ShlConst))
       return DAG.getNode(ISD::SHL, DL, VT, NN,
-                         DAG.getConstant(-SarConst, DL, CVT));
+                         DAG.getConstant(ShlConst - SraConst, DL, CVT));
     return DAG.getNode(ISD::SRA, DL, VT, NN,
-                       DAG.getConstant(SarConst, DL, CVT));
+                       DAG.getConstant(SraConst - ShlConst, DL, CVT));
   }
   return SDValue();
 }

diff  --git a/llvm/test/CodeGen/X86/sar_fold.ll b/llvm/test/CodeGen/X86/sar_fold.ll
index 7607ca386d577e..0f1396954b03a1 100644
--- a/llvm/test/CodeGen/X86/sar_fold.ll
+++ b/llvm/test/CodeGen/X86/sar_fold.ll
@@ -67,20 +67,17 @@ define void @shl144sar48(ptr %p) #0 {
   ret void
 }
 
-; This is incorrect. The 142 least significant bits in the stored value should
-; be zero, and but 142-157 should be taken from %a with a sign-extend into the
-; two most significant bits.
 define void @shl144sar2(ptr %p) #0 {
 ; CHECK-LABEL: shl144sar2:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; CHECK-NEXT:    movswl (%eax), %ecx
-; CHECK-NEXT:    sarl $31, %ecx
+; CHECK-NEXT:    shll $14, %ecx
 ; CHECK-NEXT:    movl %ecx, 16(%eax)
-; CHECK-NEXT:    movl %ecx, 8(%eax)
-; CHECK-NEXT:    movl %ecx, 12(%eax)
-; CHECK-NEXT:    movl %ecx, 4(%eax)
-; CHECK-NEXT:    movl %ecx, (%eax)
+; CHECK-NEXT:    movl $0, 8(%eax)
+; CHECK-NEXT:    movl $0, 12(%eax)
+; CHECK-NEXT:    movl $0, 4(%eax)
+; CHECK-NEXT:    movl $0, (%eax)
 ; CHECK-NEXT:    retl
   %a = load i160, ptr %p
   %1 = shl i160 %a, 144


        


More information about the llvm-commits mailing list