[llvm] [SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (PR #95563)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Jun 29 12:31:33 PDT 2024
https://github.com/Fros1er updated https://github.com/llvm/llvm-project/pull/95563
>From a2018effea48b9526ab17feb58f30319a10894d8 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Fri, 14 Jun 2024 22:28:38 +0800
Subject: [PATCH 1/5] [SelectionDAG][RISCV] Add pre-commit tests.
---
llvm/test/CodeGen/RISCV/pr94265.ll | 35 ++++++++++++++++++++++++++++++
1 file changed, 35 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/pr94265.ll
diff --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
new file mode 100644
index 0000000000000..b1dff117eb17c
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
+; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s
+
+define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
+; RV32I-LABEL: PR94265:
+; RV32I: # %bb.0:
+; RV32I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32I-NEXT: vsra.vi v10, v8, 31
+; RV32I-NEXT: vsrl.vi v10, v10, 26
+; RV32I-NEXT: vadd.vv v8, v8, v10
+; RV32I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; RV32I-NEXT: vnsrl.wi v10, v8, 0
+; RV32I-NEXT: vsll.vi v8, v10, 4
+; RV32I-NEXT: li a0, -1024
+; RV32I-NEXT: vand.vx v8, v8, a0
+; RV32I-NEXT: ret
+;
+; RV64I-LABEL: PR94265:
+; RV64I: # %bb.0:
+; RV64I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV64I-NEXT: vsra.vi v10, v8, 31
+; RV64I-NEXT: vsrl.vi v10, v10, 26
+; RV64I-NEXT: vadd.vv v8, v8, v10
+; RV64I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; RV64I-NEXT: vnsrl.wi v10, v8, 0
+; RV64I-NEXT: vsll.vi v8, v10, 4
+; RV64I-NEXT: li a0, -1024
+; RV64I-NEXT: vand.vx v8, v8, a0
+; RV64I-NEXT: ret
+ %t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
+ %t2 = trunc <8 x i32> %t1 to <8 x i16>
+ %t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+ ret <8 x i16> %t3
+}
>From 2c04c2327e14a5301b01c6eb6fd0f9aac71e2a05 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Fri, 14 Jun 2024 22:40:06 +0800
Subject: [PATCH 2/5] [SelectionDAG][RISCV] Add isTypeDesirableForOp with
NewVT+OldVT, fix issue #94265
---
llvm/include/llvm/CodeGen/TargetLowering.h | 14 ++++++++++++++
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 4 +++-
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 8 ++++++++
llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 ++
llvm/test/CodeGen/RISCV/pr94265.ll | 12 ++++--------
5 files changed, 31 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 3074ece787a08..f0e20e4372b8d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4339,6 +4339,20 @@ class TargetLowering : public TargetLoweringBase {
return isTypeLegal(VT);
}
+ /// Same as isTypeDesirableForOp(unsigned Opc, EVT VT), but also check if
+ /// the target is 'desirable' to truncate or extend OldVT to NewVT only using
+ /// the given node type, without the need of explicit trunc or ext. e.g. On
+ /// RISC-V Vector extension, vnsrl.wi can directly convert <n x i32> to <n x
+ /// i16> when shifting, with no extra trunc operations needed.
+ virtual bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const {
+ // Fallback to isTypeDesirableForOp(unsigned Opc, EVT VT).
+ if (NewVT == OldVT) {
+ return isTypeDesirableForOp(Opc, NewVT);
+ }
+ // Most of instructions are not desirable, so return false by default.
+ return false;
+ }
+
/// Return true if it is profitable for dag combiner to transform a floating
/// point op of specified opcode to a equivalent op of an integer
/// type. e.g. f32 load -> i32 load can be profitable on ARM.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 623d2e0a047ef..373aeac5e7317 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2597,7 +2597,9 @@ bool TargetLowering::SimplifyDemandedBits(
HighBits.lshrInPlace(ShVal);
HighBits = HighBits.trunc(BitWidth);
- if (!(HighBits & DemandedBits)) {
+ if (!isTypeDesirableForOp(ISD::SRL, Op.getValueType(),
+ Src.getValueType()) &&
+ !(HighBits & DemandedBits)) {
// None of the shifted in bits are needed. Add a truncate of the
// shift input, then shift it.
SDValue NewShAmt =
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b1b27f03252e0..694e0b0dff1a3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17462,6 +17462,14 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
return true;
}
+bool RISCVTargetLowering::isTypeDesirableForOp(unsigned Opc, EVT NewVT,
+ EVT OldVT) const {
+ if (Subtarget.hasStdExtV() && NewVT.isVector() && OldVT.isVector()) {
+ return true;
+ }
+ return TargetLowering::isTypeDesirableForOp(Opc, NewVT, OldVT);
+}
+
bool RISCVTargetLowering::targetShrinkDemandedConstant(
SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
TargetLoweringOpt &TLO) const {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 3b8eb3c88901a..353836783ccfb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -708,6 +708,8 @@ class RISCVTargetLowering : public TargetLowering {
bool isDesirableToCommuteWithShift(const SDNode *N,
CombineLevel Level) const override;
+ bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const override;
+
/// If a physical register, this returns the register that receives the
/// exception address on entry to an EH pad.
Register
diff --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
index b1dff117eb17c..cb41e22381d19 100644
--- a/llvm/test/CodeGen/RISCV/pr94265.ll
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -10,10 +10,8 @@ define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
; RV32I-NEXT: vsrl.vi v10, v10, 26
; RV32I-NEXT: vadd.vv v8, v8, v10
; RV32I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
-; RV32I-NEXT: vnsrl.wi v10, v8, 0
-; RV32I-NEXT: vsll.vi v8, v10, 4
-; RV32I-NEXT: li a0, -1024
-; RV32I-NEXT: vand.vx v8, v8, a0
+; RV32I-NEXT: vnsrl.wi v10, v8, 6
+; RV32I-NEXT: vsll.vi v8, v10, 10
; RV32I-NEXT: ret
;
; RV64I-LABEL: PR94265:
@@ -23,10 +21,8 @@ define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
; RV64I-NEXT: vsrl.vi v10, v10, 26
; RV64I-NEXT: vadd.vv v8, v8, v10
; RV64I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
-; RV64I-NEXT: vnsrl.wi v10, v8, 0
-; RV64I-NEXT: vsll.vi v8, v10, 4
-; RV64I-NEXT: li a0, -1024
-; RV64I-NEXT: vand.vx v8, v8, a0
+; RV64I-NEXT: vnsrl.wi v10, v8, 6
+; RV64I-NEXT: vsll.vi v8, v10, 10
; RV64I-NEXT: ret
%t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
%t2 = trunc <8 x i32> %t1 to <8 x i16>
>From f6ab73cb4fb9f0fe40637f616e0f585cfa3ae534 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Fri, 28 Jun 2024 21:47:26 +0800
Subject: [PATCH 3/5] rename new func to isTypeDesirableForOpWithCast
---
llvm/include/llvm/CodeGen/TargetLowering.h | 3 ++-
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 4 ++--
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 6 +++---
llvm/lib/Target/RISCV/RISCVISelLowering.h | 3 ++-
4 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index f0e20e4372b8d..c94c0b1f9a4e7 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4344,7 +4344,8 @@ class TargetLowering : public TargetLoweringBase {
/// the given node type, without the need of explicit trunc or ext. e.g. On
/// RISC-V Vector extension, vnsrl.wi can directly convert <n x i32> to <n x
/// i16> when shifting, with no extra trunc operations needed.
- virtual bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const {
+ virtual bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
+ EVT OldVT) const {
// Fallback to isTypeDesirableForOp(unsigned Opc, EVT VT).
if (NewVT == OldVT) {
return isTypeDesirableForOp(Opc, NewVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 373aeac5e7317..1a8748fa3d131 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2597,8 +2597,8 @@ bool TargetLowering::SimplifyDemandedBits(
HighBits.lshrInPlace(ShVal);
HighBits = HighBits.trunc(BitWidth);
- if (!isTypeDesirableForOp(ISD::SRL, Op.getValueType(),
- Src.getValueType()) &&
+ if (!isTypeDesirableForOpWithCast(ISD::SRL, Op.getValueType(),
+ Src.getValueType()) &&
!(HighBits & DemandedBits)) {
// None of the shifted in bits are needed. Add a truncate of the
// shift input, then shift it.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 694e0b0dff1a3..b1a3684835343 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17462,12 +17462,12 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
return true;
}
-bool RISCVTargetLowering::isTypeDesirableForOp(unsigned Opc, EVT NewVT,
- EVT OldVT) const {
+bool RISCVTargetLowering::isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
+ EVT OldVT) const {
if (Subtarget.hasStdExtV() && NewVT.isVector() && OldVT.isVector()) {
return true;
}
- return TargetLowering::isTypeDesirableForOp(Opc, NewVT, OldVT);
+ return TargetLowering::isTypeDesirableForOpWithCast(Opc, NewVT, OldVT);
}
bool RISCVTargetLowering::targetShrinkDemandedConstant(
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 353836783ccfb..b79f8ca67bcd5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -708,7 +708,8 @@ class RISCVTargetLowering : public TargetLowering {
bool isDesirableToCommuteWithShift(const SDNode *N,
CombineLevel Level) const override;
- bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const override;
+ bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
+ EVT OldVT) const override;
/// If a physical register, this returns the register that receives the
/// exception address on entry to an EH pad.
>From 6286d17512d4f95fdf4c616d45e60c11a60b3d39 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Sun, 30 Jun 2024 03:30:24 +0800
Subject: [PATCH 4/5] remove new func, use overridden isTruncateFree instead
---
llvm/include/llvm/CodeGen/TargetLowering.h | 15 ---------------
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 3 +--
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 16 ++++++++--------
llvm/lib/Target/RISCV/RISCVISelLowering.h | 4 +---
4 files changed, 10 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index c94c0b1f9a4e7..3074ece787a08 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4339,21 +4339,6 @@ class TargetLowering : public TargetLoweringBase {
return isTypeLegal(VT);
}
- /// Same as isTypeDesirableForOp(unsigned Opc, EVT VT), but also check if
- /// the target is 'desirable' to truncate or extend OldVT to NewVT only using
- /// the given node type, without the need of explicit trunc or ext. e.g. On
- /// RISC-V Vector extension, vnsrl.wi can directly convert <n x i32> to <n x
- /// i16> when shifting, with no extra trunc operations needed.
- virtual bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
- EVT OldVT) const {
- // Fallback to isTypeDesirableForOp(unsigned Opc, EVT VT).
- if (NewVT == OldVT) {
- return isTypeDesirableForOp(Opc, NewVT);
- }
- // Most of instructions are not desirable, so return false by default.
- return false;
- }
-
/// Return true if it is profitable for dag combiner to transform a floating
/// point op of specified opcode to a equivalent op of an integer
/// type. e.g. f32 load -> i32 load can be profitable on ARM.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 1a8748fa3d131..60cad8f5b30e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2597,8 +2597,7 @@ bool TargetLowering::SimplifyDemandedBits(
HighBits.lshrInPlace(ShVal);
HighBits = HighBits.trunc(BitWidth);
- if (!isTypeDesirableForOpWithCast(ISD::SRL, Op.getValueType(),
- Src.getValueType()) &&
+ if (!isTruncateFree(Src, Op.getValueType()) &&
!(HighBits & DemandedBits)) {
// None of the shifted in bits are needed. Add a truncate of the
// shift input, then shift it.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b1a3684835343..460ee29abd09f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1884,6 +1884,14 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
return (SrcBits == 64 && DestBits == 32);
}
+bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
+ // free truncate from vnsrl and vnsra
+ if (Subtarget.hasStdExtV() && (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) && Val.getValueType().isVector() && VT2.isVector()) {
+ return true;
+ }
+ return TargetLowering::isTruncateFree(Val, VT2);
+}
+
bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
// Zexts are free if they can be combined with a load.
// Don't advertise i32->i64 zextload as being free for RV64. It interacts
@@ -17462,14 +17470,6 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
return true;
}
-bool RISCVTargetLowering::isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
- EVT OldVT) const {
- if (Subtarget.hasStdExtV() && NewVT.isVector() && OldVT.isVector()) {
- return true;
- }
- return TargetLowering::isTypeDesirableForOpWithCast(Opc, NewVT, OldVT);
-}
-
bool RISCVTargetLowering::targetShrinkDemandedConstant(
SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
TargetLoweringOpt &TLO) const {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index b79f8ca67bcd5..d66374ec5b171 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
bool isLegalAddImmediate(int64_t Imm) const override;
bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
+ bool isTruncateFree(SDValue Val, EVT VT2) const override;
bool isZExtFree(SDValue Val, EVT VT2) const override;
bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
bool signExtendConstant(const ConstantInt *CI) const override;
@@ -708,9 +709,6 @@ class RISCVTargetLowering : public TargetLowering {
bool isDesirableToCommuteWithShift(const SDNode *N,
CombineLevel Level) const override;
- bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
- EVT OldVT) const override;
-
/// If a physical register, this returns the register that receives the
/// exception address on entry to an EH pad.
Register
>From 261506d939a775c800c4bb3d7eab945336d24acf Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Sun, 30 Jun 2024 03:31:17 +0800
Subject: [PATCH 5/5] format
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 460ee29abd09f..77eef4d0501b5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1886,7 +1886,9 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
// free truncate from vnsrl and vnsra
- if (Subtarget.hasStdExtV() && (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) && Val.getValueType().isVector() && VT2.isVector()) {
+ if (Subtarget.hasStdExtV() &&
+ (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) &&
+ Val.getValueType().isVector() && VT2.isVector()) {
return true;
}
return TargetLowering::isTruncateFree(Val, VT2);
More information about the llvm-commits
mailing list