[llvm] [AArch64] Avoid overflow when using shl lower mul (PR #97148)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 29 03:08:36 PDT 2024


https://github.com/DianQK updated https://github.com/llvm/llvm-project/pull/97148

>From 03adde0431cd429d60baba866a8f69cb316861da Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sat, 29 Jun 2024 15:01:01 +0800
Subject: [PATCH 1/3] Pre-commit test cases

---
 llvm/test/CodeGen/AArch64/mul_pow2.ll | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index c4839175ded5a..2494bd1e6c9e5 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -992,3 +992,20 @@ define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
   %r = and <4 x i32> %a, <i32 131071, i32 131071, i32 131071, i32 131071>
   ret <4 x i32> %r
 }
+
+; Transforming `(mul x, -(2^(N-M) - 1) * 2^M)` to `(sub (shl x, M), (shl x, N))`
+; is invalid when N is 32 and M is 31: `shl x, 32` is out of range for i32.
+define i32 @shift_overflow(i32 %x) {
+; CHECK-LABEL: shift_overflow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: shift_overflow:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #-2147483648 // =0x80000000
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+  %const = bitcast i32 2147483648 to i32
+  %r = mul i32 %x, %const
+  ret i32 %r
+}

>From 5b24d14d5be5c628192b9cdf815dfbc404f1a865 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sat, 29 Jun 2024 15:02:08 +0800
Subject: [PATCH 2/3] [AArch64] Avoid overflow when using shl lower mul

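Transforming `(mul x, -(2^(N-M) - 1) * 2^M)` to `(sub (shl x, M), (shl x, N))`
is invalid for an i32 multiply when N is 32 and M is 31. `shl x, 32` shifts by
the full bit width and folds to undef, so the multiplication was miscompiled
away. Skip the combine whenever a shift amount would be out of range.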
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++++++++-
 llvm/test/CodeGen/AArch64/mul_pow2.ll           |  2 ++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0d53f71a4def8..e022646969b62 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18059,16 +18059,28 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   unsigned ShiftAmt;
 
   auto Shl = [&](SDValue N0, unsigned N1) {
+    if (!N0.getNode())
+      return SDValue();
     SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
-    return DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
+    SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
+    // If shift causes overflow, ignore this combine.
+    if (SHL->isUndef())
+      return SDValue();
+    return SHL;
   };
   auto Add = [&](SDValue N0, SDValue N1) {
+    if (!N0.getNode() || !N1.getNode())
+      return SDValue();
     return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
   };
   auto Sub = [&](SDValue N0, SDValue N1) {
+    if (!N0.getNode() || !N1.getNode())
+      return SDValue();
     return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
   };
   auto Negate = [&](SDValue N) {
+    if (!N0.getNode())
+      return SDValue();
     SDValue Zero = DAG.getConstant(0, DL, VT);
     return DAG.getNode(ISD::SUB, DL, VT, Zero, N);
   };
diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 2494bd1e6c9e5..16a47c9a49a05 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -998,6 +998,8 @@ define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
 define i32 @shift_overflow(i32 %x) {
 ; CHECK-LABEL: shift_overflow:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-2147483648 // =0x80000000
+; CHECK-NEXT:    mul w0, w0, w8
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: shift_overflow:

>From 5bbb7bdca79f054bdae3d9e69887cf6d1f371604 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sat, 29 Jun 2024 17:58:46 +0800
Subject: [PATCH 3/3] address comments

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  7 +++----
 llvm/test/CodeGen/AArch64/mul_pow2.ll         | 19 +++++++++++++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e022646969b62..acce9515e832c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18061,12 +18061,11 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   auto Shl = [&](SDValue N0, unsigned N1) {
     if (!N0.getNode())
       return SDValue();
-    SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
-    SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
     // If shift causes overflow, ignore this combine.
-    if (SHL->isUndef())
+    if (N1 >= N0.getValueSizeInBits())
       return SDValue();
-    return SHL;
+    SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
+    return DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
   };
   auto Add = [&](SDValue N0, SDValue N1) {
     if (!N0.getNode() || !N1.getNode())
diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 16a47c9a49a05..7e26b877a4228 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -1011,3 +1011,22 @@ define i32 @shift_overflow(i32 %x) {
   %r = mul i32 %x, %const
   ret i32 %r
 }
+
+; Transforming `(mul x, (2^(N-M) - 1) * 2^M)` to `(sub (shl x, N), (shl x, M))`
+; is valid when N is 31 and M is 30: both shift amounts are in range for i32.
+define i32 @shift_no_overflow(i32 %x) {
+; CHECK-LABEL: shift_no_overflow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl w8, w0, #31
+; CHECK-NEXT:    sub w0, w8, w0, lsl #30
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: shift_no_overflow:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #1073741824 // =0x40000000
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+  %const = bitcast i32 1073741824 to i32
+  %r = mul i32 %x, %const
+  ret i32 %r
+}



More information about the llvm-commits mailing list