[llvm] [MIPS] Fix miscompile of 64-bit shift with masked shift amount (PR #71154)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 3 01:29:53 PDT 2023


https://github.com/yingopq created https://github.com/llvm/llvm-project/pull/71154

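In the functions lowerShiftLeftParts and lowerShiftRightParts:
1. The XOR that derives the complementary shift amount should use VT.bits-1
   as its constant, not -1.
2. The pseudo-code comments above the code are incorrect (they show ~shamt);
   update them to show (xor shamt, VT.bits-1).
3. ShiftLeftLo (in lowerShiftLeftParts) and ShiftRightHi (in
   lowerShiftRightParts) are computed from the unmasked shift amount; compute
   them from shamt masked with VT.bits-1 instead.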

Fix https://github.com/llvm/llvm-project/issues/64794

>From 3d0a76a0fed02d07307c129fd6fbe02d94b74ed9 Mon Sep 17 00:00:00 2001
From: Ying Huang <ying.huang at oss.cipunited.com>
Date: Fri, 3 Nov 2023 02:32:41 -0400
Subject: [PATCH] [MIPS] Fix miscompile of 64-bit shift with masked shift
 amount
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

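In the functions lowerShiftLeftParts and lowerShiftRightParts:
1. The XOR that derives the complementary shift amount should use VT.bits-1
   as its constant, not -1.
2. The pseudo-code comments above the code are incorrect (they show ~shamt);
   update them to show (xor shamt, VT.bits-1).
3. ShiftLeftLo (in lowerShiftLeftParts) and ShiftRightHi (in
   lowerShiftRightParts) are computed from the unmasked shift amount; compute
   them from shamt masked with VT.bits-1 instead.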

Fix https://github.com/llvm/llvm-project/issues/64794
---
 llvm/lib/Target/Mips/MipsISelLowering.cpp | 26 +++++++++++++++--------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 061d035b7e246c7..3f6122543c26c23 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -2593,18 +2593,22 @@ SDValue MipsTargetLowering::lowerShiftLeftParts(SDValue Op,
   SDValue Shamt = Op.getOperand(2);
   // if shamt < (VT.bits):
   //  lo = (shl lo, shamt)
-  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), ~shamt))
+  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), (xor shamt, VT.bits-1)))
   // else:
   //  lo = 0
   //  hi = (shl lo, shamt[4:0])
-  SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+  SDValue Not =
+      DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+		  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
   SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo,
                                       DAG.getConstant(1, DL, VT));
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, Not);
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
-  SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+		  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, ShamtMasked);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
                              DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
   Lo = DAG.getNode(ISD::SELECT, DL, VT, Cond,
@@ -2623,7 +2627,7 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   MVT VT = Subtarget.isGP64bit() ? MVT::i64 : MVT::i32;
 
   // if shamt < (VT.bits):
-  //  lo = (or (shl (shl hi, 1), ~shamt) (srl lo, shamt))
+  //  lo = (or (shl (shl hi, 1), (xor shamt, VT.bits-1)) (srl lo, shamt))
   //  if isSRA:
   //    hi = (sra hi, shamt)
   //  else:
@@ -2635,15 +2639,19 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   //  else:
   //   lo = (srl hi, shamt[4:0])
   //   hi = 0
-  SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+  SDValue Not =
+      DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+		  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
   SDValue ShiftLeft1Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
                                      DAG.getConstant(1, DL, VT));
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, ShiftLeft1Hi, Not);
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
-  SDValue ShiftRightHi = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL,
-                                     DL, VT, Hi, Shamt);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+		  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue ShiftRightHi =
+      DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
                              DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
   SDValue Ext = DAG.getNode(ISD::SRA, DL, VT, Hi,



More information about the llvm-commits mailing list