[llvm] 6c04b7d - [AArch64] Optimize overflow checks for [s|u]mul.with.overflow.i32.
Eli Friedman via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 12 15:31:14 PDT 2021
Author: Eli Friedman
Date: 2021-07-12T15:30:42-07:00
New Revision: 6c04b7dd4fb4bfcc5db10b844d6235abbb21b805
URL: https://github.com/llvm/llvm-project/commit/6c04b7dd4fb4bfcc5db10b844d6235abbb21b805
DIFF: https://github.com/llvm/llvm-project/commit/6c04b7dd4fb4bfcc5db10b844d6235abbb21b805.diff
LOG: [AArch64] Optimize overflow checks for [s|u]mul.with.overflow.i32.
Saves one instruction in the signed case, and uses a cheaper instruction in the
unsigned case (a tst with a logical immediate instead of a cmp with a shifted register).
Differential Revision: https://reviews.llvm.org/D105770
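
For reference, a minimal C++ sketch of the checks the new sequences implement
(illustrative only, not part of the patch; the helper names are made up). The
signed case compares the full 64-bit product against the sign-extension of its
low 32 bits, which is what the single "cmp xreg, wreg, sxtw" performs; the
unsigned case tests whether any of the upper 32 bits of the product are set,
which is what "tst xreg, #0xffffffff00000000" performs.

#include <cstdint>

// Illustrative helpers; names and signatures are not from the patch.
bool smulo_i32(int32_t a, int32_t b, int32_t &res) {
  int64_t Prod = (int64_t)a * (int64_t)b;      // smull xM, w0, w1
  res = (int32_t)Prod;
  return Prod != (int64_t)res;                 // cmp xM, wM, sxtw ; cset ne
}

bool umulo_i32(uint32_t a, uint32_t b, uint32_t &res) {
  uint64_t Prod = (uint64_t)a * (uint64_t)b;   // umull xM, w0, w1
  res = (uint32_t)Prod;
  return (Prod & 0xFFFFFFFF00000000ULL) != 0;  // tst xM, #0xffffffff00000000 ; cset ne
}

The previous signed sequence needed a separate lsr feeding the compare, and the
previous unsigned sequence used a cmp against a lsr #32 shifted operand, as can
be seen in the removed lines of the test updates below.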
Added:
Modified:
llvm/lib/Target/AArch64/AArch64FastISel.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/arm64-xaluo.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 88f0b545c67c..9acda17b816f 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -3681,11 +3681,13 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
if (VT == MVT::i32) {
MulReg = emitSMULL_rr(MVT::i64, LHSReg, RHSReg);
- unsigned ShiftReg = emitLSR_ri(MVT::i64, MVT::i64, MulReg, 32);
- MulReg = fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
- ShiftReg = fastEmitInst_extractsubreg(VT, ShiftReg, AArch64::sub_32);
- emitSubs_rs(VT, ShiftReg, MulReg, AArch64_AM::ASR, 31,
- /*WantResult=*/false);
+ unsigned MulSubReg =
+ fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
+ // cmp xreg, wreg, sxtw
+ emitAddSub_rx(/*UseAdd=*/false, MVT::i64, MulReg, MulSubReg,
+ AArch64_AM::SXTW, /*ShiftImm=*/0, /*SetFlags=*/true,
+ /*WantResult=*/false);
+ MulReg = MulSubReg;
} else {
assert(VT == MVT::i64 && "Unexpected value type.");
// LHSReg and RHSReg cannot be killed by this Mul, since they are
@@ -3709,8 +3711,11 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
if (VT == MVT::i32) {
MulReg = emitUMULL_rr(MVT::i64, LHSReg, RHSReg);
- emitSubs_rs(MVT::i64, AArch64::XZR, MulReg, AArch64_AM::LSR, 32,
- /*WantResult=*/false);
+ // tst xreg, #0xffffffff00000000
+ BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
+ TII.get(AArch64::ANDSXri), AArch64::XZR)
+ .addReg(MulReg)
+ .addImm(AArch64_AM::encodeLogicalImmediate(0xFFFFFFFF00000000, 64));
MulReg = fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
} else {
assert(VT == MVT::i64 && "Unexpected value type.");
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f9a90a01f7c5..662a1d458605 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3012,50 +3012,25 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
CC = AArch64CC::NE;
bool IsSigned = Op.getOpcode() == ISD::SMULO;
if (Op.getValueType() == MVT::i32) {
+ // Extend to 64-bits, then perform a 64-bit multiply.
unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
- // For a 32 bit multiply with overflow check we want the instruction
- // selector to generate a widening multiply (SMADDL/UMADDL). For that we
- // need to generate the following pattern:
- // (i64 add 0, (i64 mul (i64 sext|zext i32 %a), (i64 sext|zext i32 %b))
LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::i64, Mul,
- DAG.getConstant(0, DL, MVT::i64));
- // On AArch64 the upper 32 bits are always zero extended for a 32 bit
- // operation. We need to clear out the upper 32 bits, because we used a
- // widening multiply that wrote all 64 bits. In the end this should be a
- // noop.
- Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Add);
+ Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
+
+ // Check that the result fits into a 32-bit integer.
+ SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
if (IsSigned) {
- // The signed overflow check requires more than just a simple check for
- // any bit set in the upper 32 bits of the result. These bits could be
- // just the sign bits of a negative number. To perform the overflow
- // check we have to arithmetic shift right the 32nd bit of the result by
- // 31 bits. Then we compare the result to the upper 32 bits.
- SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Add,
- DAG.getConstant(32, DL, MVT::i64));
- UpperBits = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, UpperBits);
- SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i32, Value,
- DAG.getConstant(31, DL, MVT::i64));
- // It is important that LowerBits is last, otherwise the arithmetic
- // shift will not be folded into the compare (SUBS).
- SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32);
- Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
- .getValue(1);
+ // cmp xreg, wreg, sxtw
+ SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
+ Overflow =
+ DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1);
} else {
- // The overflow check for unsigned multiply is easy. We only need to
- // check if any of the upper 32 bits are set. This can be done with a
- // CMP (shifted register). For that we need to generate the following
- // pattern:
- // (i64 AArch64ISD::SUBS i64 0, (i64 srl i64 %Mul, i64 32)
- SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Mul,
- DAG.getConstant(32, DL, MVT::i64));
- SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
+ // tst xreg, #0xffffffff00000000
+ SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64);
Overflow =
- DAG.getNode(AArch64ISD::SUBS, DL, VTs,
- DAG.getConstant(0, DL, MVT::i64),
- UpperBits).getValue(1);
+ DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1);
}
break;
}
diff --git a/llvm/test/CodeGen/AArch64/arm64-xaluo.ll b/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
index 6ae5b3556413..d8f5db89954f 100644
--- a/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
@@ -202,8 +202,7 @@ define zeroext i1 @smulo.i32(i32 %v1, i32 %v2, i32* %res) {
entry:
; CHECK-LABEL: smulo.i32
; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
; CHECK-NEXT: cset {{w[0-9]+}}, ne
%t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
%val = extractvalue {i32, i1} %t, 0
@@ -242,7 +241,7 @@ define zeroext i1 @umulo.i32(i32 %v1, i32 %v2, i32* %res) {
entry:
; CHECK-LABEL: umulo.i32
; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
; CHECK-NEXT: cset {{w[0-9]+}}, ne
%t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
%val = extractvalue {i32, i1} %t, 0
@@ -460,8 +459,7 @@ define i32 @smulo.select.i32(i32 %v1, i32 %v2) {
entry:
; CHECK-LABEL: smulo.select.i32
; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
; CHECK-NEXT: csel w0, w0, w1, ne
%t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
%obit = extractvalue {i32, i1} %t, 1
@@ -473,8 +471,7 @@ define i1 @smulo.not.i32(i32 %v1, i32 %v2) {
entry:
; CHECK-LABEL: smulo.not.i32
; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
; CHECK-NEXT: cset w0, eq
%t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
%obit = extractvalue {i32, i1} %t, 1
@@ -512,7 +509,7 @@ define i32 @umulo.select.i32(i32 %v1, i32 %v2) {
entry:
; CHECK-LABEL: umulo.select.i32
; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
; CHECK-NEXT: csel w0, w0, w1, ne
%t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
%obit = extractvalue {i32, i1} %t, 1
@@ -524,7 +521,7 @@ define i1 @umulo.not.i32(i32 %v1, i32 %v2) {
entry:
; CHECK-LABEL: umulo.not.i32
; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
; CHECK-NEXT: cset w0, eq
%t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
%obit = extractvalue {i32, i1} %t, 1
@@ -700,8 +697,7 @@ define zeroext i1 @smulo.br.i32(i32 %v1, i32 %v2) {
entry:
; CHECK-LABEL: smulo.br.i32
; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x8, #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
; CHECK-NEXT: b.eq
%t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
%val = extractvalue {i32, i1} %t, 0
@@ -755,7 +751,7 @@ define zeroext i1 @umulo.br.i32(i32 %v1, i32 %v2) {
entry:
; CHECK-LABEL: umulo.br.i32
; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
; CHECK-NEXT: b.eq
%t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
%val = extractvalue {i32, i1} %t, 0
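As a usage note (an illustrative example, not part of the patch): overflow-checked
32-bit multiplies written with __builtin_mul_overflow in Clang lower to
llvm.[s|u]mul.with.overflow.i32, so code like the following picks up the shorter
sequences shown in the updated tests; exact register allocation may differ.

#include <cstdint>

// On AArch64 the signed check should now select smull + cmp (sxtw) instead of
// smull + lsr + cmp (asr #31).
bool checked_mul(int32_t a, int32_t b, int32_t *out) {
  return __builtin_mul_overflow(a, b, out);
}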