[llvm] [MIPS] Fix miscompile of 64-bit shift with masked shift amount (PR #71154)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 3 01:29:53 PDT 2023
https://github.com/yingopq created https://github.com/llvm/llvm-project/pull/71154
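In functions lowerShiftRightParts and lowerShiftLeftParts:
1. The XOR that computes the complementary shift amount should use VT.bits-1, not -1. (~shamt is always >= VT.bits, which is an out-of-range amount for ISD::SHL/SRL; it only worked because the hardware ignores the upper bits of the shift amount.)
2. The comments above the code still describe ~shamt; they are updated to (xor shamt, VT.bits-1).
3. ShiftLeftLo (in lowerShiftLeftParts) and ShiftRightHi (in lowerShiftRightParts) shifted by the unmasked shamt. They also produce the result when shamt >= VT.bits, so they must shift by shamt & (VT.bits-1).
Fixes https://github.com/llvm/llvm-project/issues/64794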
>From 3d0a76a0fed02d07307c129fd6fbe02d94b74ed9 Mon Sep 17 00:00:00 2001
From: Ying Huang <ying.huang at oss.cipunited.com>
Date: Fri, 3 Nov 2023 02:32:41 -0400
Subject: [PATCH] [MIPS] Fix miscompile of 64-bit shift with masked shift
amount
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
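In functions lowerShiftRightParts and lowerShiftLeftParts:
1. The XOR that computes the complementary shift amount should use VT.bits-1, not -1. (~shamt is always >= VT.bits, which is an out-of-range amount for ISD::SHL/SRL; it only worked because the hardware ignores the upper bits of the shift amount.)
2. The comments above the code still describe ~shamt; they are updated to (xor shamt, VT.bits-1).
3. ShiftLeftLo (in lowerShiftLeftParts) and ShiftRightHi (in lowerShiftRightParts) shifted by the unmasked shamt. They also produce the result when shamt >= VT.bits, so they must shift by shamt & (VT.bits-1).
Fixes https://github.com/llvm/llvm-project/issues/64794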
---
llvm/lib/Target/Mips/MipsISelLowering.cpp | 26 +++++++++++++++--------
1 file changed, 17 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 061d035b7e246c7..3f6122543c26c23 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -2593,18 +2593,22 @@ SDValue MipsTargetLowering::lowerShiftLeftParts(SDValue Op,
SDValue Shamt = Op.getOperand(2);
// if shamt < (VT.bits):
// lo = (shl lo, shamt)
- // hi = (or (shl hi, shamt) (srl (srl lo, 1), ~shamt))
+ // hi = (or (shl hi, shamt) (srl (srl lo, 1), (xor shamt, VT.bits-1)))
// else:
// lo = 0
// hi = (shl lo, shamt[4:0])
- SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
- DAG.getConstant(-1, DL, MVT::i32));
+ SDValue Not =
+ DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+ DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo,
DAG.getConstant(1, DL, VT));
SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, Not);
SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
- SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
+ SDValue ShamtMasked =
+ DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+ DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+ SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, ShamtMasked);
SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
Lo = DAG.getNode(ISD::SELECT, DL, VT, Cond,
@@ -2623,7 +2627,7 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
MVT VT = Subtarget.isGP64bit() ? MVT::i64 : MVT::i32;
// if shamt < (VT.bits):
- // lo = (or (shl (shl hi, 1), ~shamt) (srl lo, shamt))
+ // lo = (or (shl (shl hi, 1), (xor shamt, VT.bits-1)) (srl lo, shamt))
// if isSRA:
// hi = (sra hi, shamt)
// else:
@@ -2635,15 +2639,19 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
// else:
// lo = (srl hi, shamt[4:0])
// hi = 0
- SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
- DAG.getConstant(-1, DL, MVT::i32));
+ SDValue Not =
+ DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+ DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
SDValue ShiftLeft1Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
DAG.getConstant(1, DL, VT));
SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, ShiftLeft1Hi, Not);
SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
- SDValue ShiftRightHi = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL,
- DL, VT, Hi, Shamt);
+ SDValue ShamtMasked =
+ DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+ DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+ SDValue ShiftRightHi =
+ DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
SDValue Ext = DAG.getNode(ISD::SRA, DL, VT, Hi,
More information about the llvm-commits
mailing list