[llvm] 6c04b7d - [AArch64] Optimize overflow checks for [s|u]mul.with.overflow.i32.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 12 15:31:14 PDT 2021


Author: Eli Friedman
Date: 2021-07-12T15:30:42-07:00
New Revision: 6c04b7dd4fb4bfcc5db10b844d6235abbb21b805

URL: https://github.com/llvm/llvm-project/commit/6c04b7dd4fb4bfcc5db10b844d6235abbb21b805
DIFF: https://github.com/llvm/llvm-project/commit/6c04b7dd4fb4bfcc5db10b844d6235abbb21b805.diff

LOG: [AArch64] Optimize overflow checks for [s|u]mul.with.overflow.i32.

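Saves one instruction in the signed case: a single `cmp xN, wN, sxtw`
replaces the `lsr` + `cmp ..., asr #31` pair. Uses a cheaper instruction in
the unsigned case: `tst xN, #0xffffffff00000000` replaces
`cmp xzr, xN, lsr #32`.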

Differential Revision: https://reviews.llvm.org/D105770

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64FastISel.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/arm64-xaluo.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 88f0b545c67c..9acda17b816f 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -3681,11 +3681,13 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
 
       if (VT == MVT::i32) {
         MulReg = emitSMULL_rr(MVT::i64, LHSReg, RHSReg);
-        unsigned ShiftReg = emitLSR_ri(MVT::i64, MVT::i64, MulReg, 32);
-        MulReg = fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
-        ShiftReg = fastEmitInst_extractsubreg(VT, ShiftReg, AArch64::sub_32);
-        emitSubs_rs(VT, ShiftReg, MulReg, AArch64_AM::ASR, 31,
-                    /*WantResult=*/false);
+        unsigned MulSubReg =
+            fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
+        // cmp xreg, wreg, sxtw
+        emitAddSub_rx(/*UseAdd=*/false, MVT::i64, MulReg, MulSubReg,
+                      AArch64_AM::SXTW, /*ShiftImm=*/0, /*SetFlags=*/true,
+                      /*WantResult=*/false);
+        MulReg = MulSubReg;
       } else {
         assert(VT == MVT::i64 && "Unexpected value type.");
         // LHSReg and RHSReg cannot be killed by this Mul, since they are
@@ -3709,8 +3711,11 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
 
       if (VT == MVT::i32) {
         MulReg = emitUMULL_rr(MVT::i64, LHSReg, RHSReg);
-        emitSubs_rs(MVT::i64, AArch64::XZR, MulReg, AArch64_AM::LSR, 32,
-                    /*WantResult=*/false);
+        // tst xreg, #0xffffffff00000000
+        BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
+                TII.get(AArch64::ANDSXri), AArch64::XZR)
+            .addReg(MulReg)
+            .addImm(AArch64_AM::encodeLogicalImmediate(0xFFFFFFFF00000000, 64));
         MulReg = fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
       } else {
         assert(VT == MVT::i64 && "Unexpected value type.");

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f9a90a01f7c5..662a1d458605 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3012,50 +3012,25 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
     CC = AArch64CC::NE;
     bool IsSigned = Op.getOpcode() == ISD::SMULO;
     if (Op.getValueType() == MVT::i32) {
+      // Extend to 64-bits, then perform a 64-bit multiply.
       unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-      // For a 32 bit multiply with overflow check we want the instruction
-      // selector to generate a widening multiply (SMADDL/UMADDL). For that we
-      // need to generate the following pattern:
-      // (i64 add 0, (i64 mul (i64 sext|zext i32 %a), (i64 sext|zext i32 %b))
       LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
       RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
       SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
-      SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::i64, Mul,
-                                DAG.getConstant(0, DL, MVT::i64));
-      // On AArch64 the upper 32 bits are always zero extended for a 32 bit
-      // operation. We need to clear out the upper 32 bits, because we used a
-      // widening multiply that wrote all 64 bits. In the end this should be a
-      // noop.
-      Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Add);
+      Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
+
+      // Check that the result fits into a 32-bit integer.
+      SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
       if (IsSigned) {
-        // The signed overflow check requires more than just a simple check for
-        // any bit set in the upper 32 bits of the result. These bits could be
-        // just the sign bits of a negative number. To perform the overflow
-        // check we have to arithmetic shift right the 32nd bit of the result by
-        // 31 bits. Then we compare the result to the upper 32 bits.
-        SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Add,
-                                        DAG.getConstant(32, DL, MVT::i64));
-        UpperBits = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, UpperBits);
-        SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i32, Value,
-                                        DAG.getConstant(31, DL, MVT::i64));
-        // It is important that LowerBits is last, otherwise the arithmetic
-        // shift will not be folded into the compare (SUBS).
-        SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32);
-        Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
-                       .getValue(1);
+        // cmp xreg, wreg, sxtw
+        SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
+        Overflow =
+            DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1);
       } else {
-        // The overflow check for unsigned multiply is easy. We only need to
-        // check if any of the upper 32 bits are set. This can be done with a
-        // CMP (shifted register). For that we need to generate the following
-        // pattern:
-        // (i64 AArch64ISD::SUBS i64 0, (i64 srl i64 %Mul, i64 32)
-        SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Mul,
-                                        DAG.getConstant(32, DL, MVT::i64));
-        SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
+        // tst xreg, #0xffffffff00000000
+        SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64);
         Overflow =
-            DAG.getNode(AArch64ISD::SUBS, DL, VTs,
-                        DAG.getConstant(0, DL, MVT::i64),
-                        UpperBits).getValue(1);
+            DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1);
       }
       break;
     }

diff  --git a/llvm/test/CodeGen/AArch64/arm64-xaluo.ll b/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
index 6ae5b3556413..d8f5db89954f 100644
--- a/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
@@ -202,8 +202,7 @@ define zeroext i1 @smulo.i32(i32 %v1, i32 %v2, i32* %res) {
 entry:
 ; CHECK-LABEL:  smulo.i32
 ; CHECK:        smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT:   lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT:   cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT:   cmp x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT:   cset {{w[0-9]+}}, ne
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
@@ -242,7 +241,7 @@ define zeroext i1 @umulo.i32(i32 %v1, i32 %v2, i32* %res) {
 entry:
 ; CHECK-LABEL:  umulo.i32
 ; CHECK:        umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT:   cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT:   tst [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT:   cset {{w[0-9]+}}, ne
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
@@ -460,8 +459,7 @@ define i32 @smulo.select.i32(i32 %v1, i32 %v2) {
 entry:
 ; CHECK-LABEL:  smulo.select.i32
 ; CHECK:        smull   x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT:   lsr     x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT:   cmp     w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT:   cmp     x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT:   csel    w0, w0, w1, ne
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -473,8 +471,7 @@ define i1 @smulo.not.i32(i32 %v1, i32 %v2) {
 entry:
 ; CHECK-LABEL:  smulo.not.i32
 ; CHECK:        smull   x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT:   lsr     x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT:   cmp     w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT:   cmp     x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT:   cset    w0, eq
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -512,7 +509,7 @@ define i32 @umulo.select.i32(i32 %v1, i32 %v2) {
 entry:
 ; CHECK-LABEL:  umulo.select.i32
 ; CHECK:        umull   [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT:   cmp     xzr, [[MREG]], lsr #32
+; CHECK-NEXT:   tst     [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT:   csel    w0, w0, w1, ne
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -524,7 +521,7 @@ define i1 @umulo.not.i32(i32 %v1, i32 %v2) {
 entry:
 ; CHECK-LABEL:  umulo.not.i32
 ; CHECK:        umull   [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT:   cmp     xzr, [[MREG]], lsr #32
+; CHECK-NEXT:   tst     [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT:   cset    w0, eq
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -700,8 +697,7 @@ define zeroext i1 @smulo.br.i32(i32 %v1, i32 %v2) {
 entry:
 ; CHECK-LABEL:  smulo.br.i32
 ; CHECK:        smull   x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT:   lsr     x[[SREG:[0-9]+]], x8, #32
-; CHECK-NEXT:   cmp     w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT:   cmp     x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT:   b.eq
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
@@ -755,7 +751,7 @@ define zeroext i1 @umulo.br.i32(i32 %v1, i32 %v2) {
 entry:
 ; CHECK-LABEL:  umulo.br.i32
 ; CHECK:        umull   [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT:   cmp     xzr, [[MREG]], lsr #32
+; CHECK-NEXT:   tst     [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT:   b.eq
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0


        


More information about the llvm-commits mailing list