[llvm] fffb6e6 - [AArch64] Fix sub with carry

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Fri May 6 11:04:24 PDT 2022


Author: Kazu Hirata
Date: 2022-05-06T11:04:17-07:00
New Revision: fffb6e6afdbaba563189c1f715058ed401fbc88d

URL: https://github.com/llvm/llvm-project/commit/fffb6e6afdbaba563189c1f715058ed401fbc88d
DIFF: https://github.com/llvm/llvm-project/commit/fffb6e6afdbaba563189c1f715058ed401fbc88d.diff

LOG: [AArch64] Fix sub with carry

Commit 13403a70e45b2d22878ba59fc211f8dba3a8deba introduced a bug where
we generate the outgoing carry inverted.  This in turn breaks the
lowering of @llvm.usub.sat.i128: it returns the wrapped-around
difference when the result should saturate to zero, and zero otherwise.

Note that AArch64 has peculiar semantics: its subtraction instructions
produce the borrow inverted (the C flag is 1 when no borrow occurs and
0 when one does).  The problem is that we mix the two conventions --
normal carry and inverted carry -- in extended-precision subtractions.
Specifically, we have three problems:

- lowerADDSUBCARRY takes the non-inverted incoming carry from a
  subtraction and feeds it to SBCS without inverting it first.

- lowerADDSUBCARRY returns the outgoing carry from SBCS without
  inverting it first.

- foldOverflowCheck folds:

  (SBC{S} l r (CMP (CSET LO carry) 1)) => (SBC{S} l r carry)

  When the incoming carry flag is set, CSET LO yields zero, so CMP
  computes 0 - 1 and generates a borrow, *clearing* the carry flag --
  the opposite of the incoming value.  Instead, we should fold:

  (SBC{S} l r (CMP 0 (CSET LO carry))) => (SBC{S} l r carry)

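  When the incoming carry flag is set, CSET LO yields zero, so CMP
  computes 0 - 0 and does not generate a borrow, *setting* the carry
  flag.  When the incoming carry flag is clear, CSET LO yields one, so
  CMP computes 0 - 1 and borrows, clearing the carry flag.  Either way,
  the carry flag seen by SBC{S} matches the incoming one.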

If I understand correctly, we should use the normal (that is,
non-inverted) carry semantics everywhere, converting to and from
AArch64's inverted form only where SBCS is emitted.

This patch fixes the three problems above.

This patch does not add any new testcases because we already have
plenty of them covering the instruction in question.  In particular,
@u128_saturating_sub is identical to the testcase in the motivating
issue.

Fixes: #55253

Differential Revision: https://reviews.llvm.org/D124976

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/i128-math.ll
    llvm/test/CodeGen/AArch64/usub_sat_vec.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f178db8d71b4..42c44eeeb23e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3308,23 +3308,29 @@ static SDValue LowerADDC_ADDE_SUBC_SUBE(SDValue Op, SelectionDAG &DAG) {
                      Op.getOperand(2));
 }
 
-// Sets 'C' bit of NZCV to 0 if value is 0, else sets 'C' bit to 1
-static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG) {
+// If Invert is false, sets 'C' bit of NZCV to 0 if value is 0, else sets 'C'
+// bit to 1. If Invert is true, sets 'C' bit of NZCV to 1 if value is 0, else
+// sets 'C' bit to 0.
+static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG, bool Invert) {
   SDLoc DL(Value);
-  SDValue One = DAG.getConstant(1, DL, Value.getValueType());
+  EVT VT = Value.getValueType();
+  SDValue Op0 = Invert ? DAG.getConstant(0, DL, VT) : Value;
+  SDValue Op1 = Invert ? Value : DAG.getConstant(1, DL, VT);
   SDValue Cmp =
-      DAG.getNode(AArch64ISD::SUBS, DL,
-                  DAG.getVTList(Value.getValueType(), MVT::Glue), Value, One);
+      DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::Glue), Op0, Op1);
   return Cmp.getValue(1);
 }
 
-// Value is 1 if 'C' bit of NZCV is 1, else 0
-static SDValue carryFlagToValue(SDValue Flag, EVT VT, SelectionDAG &DAG) {
+// If Invert is false, value is 1 if 'C' bit of NZCV is 1, else 0.
+// If Invert is true, value is 0 if 'C' bit of NZCV is 1, else 1.
+static SDValue carryFlagToValue(SDValue Flag, EVT VT, SelectionDAG &DAG,
+                                bool Invert) {
   assert(Flag.getResNo() == 1);
   SDLoc DL(Flag);
   SDValue Zero = DAG.getConstant(0, DL, VT);
   SDValue One = DAG.getConstant(1, DL, VT);
-  SDValue CC = DAG.getConstant(AArch64CC::HS, DL, MVT::i32);
+  unsigned Cond = Invert ? AArch64CC::LO : AArch64CC::HS;
+  SDValue CC = DAG.getConstant(Cond, DL, MVT::i32);
   return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Flag);
 }
 
@@ -3348,9 +3354,10 @@ static SDValue lowerADDSUBCARRY(SDValue Op, SelectionDAG &DAG, unsigned Opcode,
   if (VT0 != MVT::i32 && VT0 != MVT::i64)
     return SDValue();
 
+  bool InvertCarry = Opcode == AArch64ISD::SBCS;
   SDValue OpLHS = Op.getOperand(0);
   SDValue OpRHS = Op.getOperand(1);
-  SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG);
+  SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG, InvertCarry);
 
   SDLoc DL(Op);
   SDVTList VTs = DAG.getVTList(VT0, VT1);
@@ -3358,8 +3365,9 @@ static SDValue lowerADDSUBCARRY(SDValue Op, SelectionDAG &DAG, unsigned Opcode,
   SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, MVT::Glue), OpLHS,
                             OpRHS, OpCarryIn);
 
-  SDValue OutFlag = IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
-                             : carryFlagToValue(Sum.getValue(1), VT1, DAG);
+  SDValue OutFlag =
+      IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
+               : carryFlagToValue(Sum.getValue(1), VT1, DAG, InvertCarry);
 
   return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Sum, OutFlag);
 }
@@ -15517,13 +15525,21 @@ static Optional<AArch64CC::CondCode> getCSETCondCode(SDValue Op) {
 }
 
 // (ADC{S} l r (CMP (CSET HS carry) 1)) => (ADC{S} l r carry)
-// (SBC{S} l r (CMP (CSET LO carry) 1)) => (SBC{S} l r carry)
+// (SBC{S} l r (CMP 0 (CSET LO carry))) => (SBC{S} l r carry)
 static SDValue foldOverflowCheck(SDNode *Op, SelectionDAG &DAG, bool IsAdd) {
   SDValue CmpOp = Op->getOperand(2);
-  if (!(isCMP(CmpOp) && isOneConstant(CmpOp.getOperand(1))))
+  if (!isCMP(CmpOp))
     return SDValue();
 
-  SDValue CsetOp = CmpOp->getOperand(0);
+  if (IsAdd) {
+    if (!isOneConstant(CmpOp.getOperand(1)))
+      return SDValue();
+  } else {
+    if (!isNullConstant(CmpOp.getOperand(0)))
+      return SDValue();
+  }
+
+  SDValue CsetOp = CmpOp->getOperand(IsAdd ? 0 : 1);
   auto CC = getCSETCondCode(CsetOp);
   if (CC != (IsAdd ? AArch64CC::HS : AArch64CC::LO))
     return SDValue();

diff  --git a/llvm/test/CodeGen/AArch64/i128-math.ll b/llvm/test/CodeGen/AArch64/i128-math.ll
index 4c5308411092..4dfdab71a18e 100644
--- a/llvm/test/CodeGen/AArch64/i128-math.ll
+++ b/llvm/test/CodeGen/AArch64/i128-math.ll
@@ -92,7 +92,7 @@ define { i128, i8 } @u128_checked_sub(i128 %x, i128 %y) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x0, x0, x2
 ; CHECK-NEXT:    sbcs x1, x1, x3
-; CHECK-NEXT:    cset w8, hs
+; CHECK-NEXT:    cset w8, lo
 ; CHECK-NEXT:    eor w2, w8, #0x1
 ; CHECK-NEXT:    ret
   %1 = tail call { i128, i1 } @llvm.usub.with.overflow.i128(i128 %x, i128 %y)
@@ -110,7 +110,7 @@ define { i128, i8 } @u128_overflowing_sub(i128 %x, i128 %y) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x0, x0, x2
 ; CHECK-NEXT:    sbcs x1, x1, x3
-; CHECK-NEXT:    cset w2, hs
+; CHECK-NEXT:    cset w2, lo
 ; CHECK-NEXT:    ret
   %1 = tail call { i128, i1 } @llvm.usub.with.overflow.i128(i128 %x, i128 %y)
   %2 = extractvalue { i128, i1 } %1, 0
@@ -126,7 +126,7 @@ define i128 @u128_saturating_sub(i128 %x, i128 %y) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x8, x0, x2
 ; CHECK-NEXT:    sbcs x9, x1, x3
-; CHECK-NEXT:    cset w10, hs
+; CHECK-NEXT:    cset w10, lo
 ; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    csel x0, xzr, x8, ne
 ; CHECK-NEXT:    csel x1, xzr, x9, ne

diff  --git a/llvm/test/CodeGen/AArch64/usub_sat_vec.ll b/llvm/test/CodeGen/AArch64/usub_sat_vec.ll
index 703064693937..9ed64b7c7b2f 100644
--- a/llvm/test/CodeGen/AArch64/usub_sat_vec.ll
+++ b/llvm/test/CodeGen/AArch64/usub_sat_vec.ll
@@ -346,13 +346,13 @@ define <2 x i128> @v2i128(<2 x i128> %x, <2 x i128> %y) nounwind {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    subs x8, x2, x6
 ; CHECK-NEXT:    sbcs x9, x3, x7
-; CHECK-NEXT:    cset w10, hs
+; CHECK-NEXT:    cset w10, lo
 ; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    csel x2, xzr, x8, ne
 ; CHECK-NEXT:    csel x3, xzr, x9, ne
 ; CHECK-NEXT:    subs x8, x0, x4
 ; CHECK-NEXT:    sbcs x9, x1, x5
-; CHECK-NEXT:    cset w10, hs
+; CHECK-NEXT:    cset w10, lo
 ; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    csel x8, xzr, x8, ne
 ; CHECK-NEXT:    csel x1, xzr, x9, ne


        


More information about the llvm-commits mailing list