[llvm] [AArch64][InstCombine] Bail from combining SRAD on +/-1 divisor (PR #109274)

Matthew Devereau via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 19 05:54:41 PDT 2024


https://github.com/MDevereau updated https://github.com/llvm/llvm-project/pull/109274

>From b582da3e0fd6d089aa848b748b32b520d528bed5 Mon Sep 17 00:00:00 2001
From: Matt Devereau <matthew.devereau at arm.com>
Date: Wed, 18 Sep 2024 18:04:53 +0000
Subject: [PATCH 1/2] [AArch64][InstCombine] Bail from combining SRAD on +/-1
 divisor

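This fixes a crash in instCombineSVESDIV when svdiv's third parameter (the
divisor) is svdup_s64(1). The combine now bails out when the splat divisor
is +1 or -1 instead of attempting the SRAD transformation.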
---
 .../AArch64/AArch64TargetTransformInfo.cpp    |  3 +++
 .../InstCombine/AArch64/sve-intrinsic-sdiv.ll | 24 +++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 649ba1ac8749f5..686503a794d7be 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1972,7 +1972,10 @@ static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
   if (!SplatConstantInt)
     return std::nullopt;
+
   APInt Divisor = SplatConstantInt->getValue();
+  if (Divisor.abs().getZExtValue() == 1)
+    return std::nullopt;
 
   if (Divisor.isPowerOf2()) {
     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
index 68d871355e2a23..dc134ef109fe75 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
@@ -68,6 +68,30 @@ define <vscale x 4 x i32> @sdiv_i32_not_zero(<vscale x 4 x i32> %a, <vscale x 4
   ret <vscale x 4 x i32> %out
 }
 
+; Don't instcombine to SRAD when the divisor is (+/-)1
+define <vscale x 2 x i64> @divide_by_1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
+; CHECK-LABEL: @divide_by_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 1)
+  %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> %2, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %3
+}
+
+define <vscale x 2 x i64> @divide_by_m1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
+; CHECK-LABEL: @divide_by_m1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 -1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> %2, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %3
+}
 
 declare <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)

>From 2a7cd161789a1f9a1232ca62f7ca8f7d1f059194 Mon Sep 17 00:00:00 2001
From: Matt Devereau <matthew.devereau at arm.com>
Date: Thu, 19 Sep 2024 12:47:41 +0000
Subject: [PATCH 2/2] Return 2nd param when divisor is 1

---
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp     | 5 ++++-
 .../Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll   | 7 +++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 686503a794d7be..a1371b21a7495e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1974,8 +1974,11 @@ static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
     return std::nullopt;
 
   APInt Divisor = SplatConstantInt->getValue();
-  if (Divisor.abs().getZExtValue() == 1)
+  const int64_t DivisorValue = Divisor.getSExtValue();
+  if (DivisorValue == -1)
     return std::nullopt;
+  if (DivisorValue == 1)
+    IC.replaceInstUsesWith(II, II.getOperand(1));
 
   if (Divisor.isPowerOf2()) {
     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
index dc134ef109fe75..cc184b79691eb7 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
@@ -68,12 +68,10 @@ define <vscale x 4 x i32> @sdiv_i32_not_zero(<vscale x 4 x i32> %a, <vscale x 4
   ret <vscale x 4 x i32> %out
 }
 
-; Don't instcombine to SRAD when the divisor is (+/-)1
+; Vec/1 is a no-op
 define <vscale x 2 x i64> @divide_by_1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
 ; CHECK-LABEL: @divide_by_1(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
-; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[A:%.*]]
 ;
   %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 1)
   %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
@@ -81,6 +79,7 @@ define <vscale x 2 x i64> @divide_by_1(<vscale x 16 x i1> %p, <vscale x 2 x i64>
   ret <vscale x 2 x i64> %3
 }
 
+; Don't instcombine to SRAD when the divisor is -1
 define <vscale x 2 x i64> @divide_by_m1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
 ; CHECK-LABEL: @divide_by_m1(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])



More information about the llvm-commits mailing list