[llvm] c8dc21d - [SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (#95563)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 14 04:09:41 PDT 2024
Author: Froster
Date: 2024-07-14T12:09:37+01:00
New Revision: c8dc21d77fc82d9360953100aa328a13185f8ba0
URL: https://github.com/llvm/llvm-project/commit/c8dc21d77fc82d9360953100aa328a13185f8ba0
DIFF: https://github.com/llvm/llvm-project/commit/c8dc21d77fc82d9360953100aa328a13185f8ba0.diff
LOG: [SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (#95563)
Added a RISCV overload of `isTruncateFree` to fix the break of vnsrl described in issue #94265.
Fixes #94265
Added:
llvm/test/CodeGen/RISCV/pr94265.ll
Modified:
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 690a86bd4606c..92e18a4b630e9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2586,6 +2586,17 @@ bool TargetLowering::SimplifyDemandedBits(
break;
if (Src.getNode()->hasOneUse()) {
+ if (isTruncateFree(Src, VT) &&
+ !isTruncateFree(Src.getValueType(), VT)) {
+ // If truncate is only free at trunc(srl), do not turn it into
+ // srl(trunc). The check is done by first check the truncate is free
+ // at Src's opcode(srl), then check the truncate is not done by
+ // referencing sub-register. In test, if both trunc(srl) and
+ // srl(trunc)'s trunc are free, srl(trunc) performs better. If only
+ // trunc(srl)'s trunc is free, trunc(srl) is better.
+ break;
+ }
+
std::optional<uint64_t> ShAmtC =
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
if (!ShAmtC || *ShAmtC >= BitWidth)
@@ -2596,7 +2607,6 @@ bool TargetLowering::SimplifyDemandedBits(
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
HighBits.lshrInPlace(ShVal);
HighBits = HighBits.trunc(BitWidth);
-
if (!(HighBits & DemandedBits)) {
// None of the shifted in bits are needed. Add a truncate of the
// shift input, then shift it.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b8ba25df9910b..caa4ebacc41da 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1894,6 +1894,21 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
return (SrcBits == 64 && DestBits == 32);
}
+bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
+ EVT SrcVT = Val.getValueType();
+ // free truncate from vnsrl and vnsra
+ if (Subtarget.hasStdExtV() &&
+ (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) &&
+ SrcVT.isVector() && VT2.isVector()) {
+ unsigned SrcBits = SrcVT.getVectorElementType().getSizeInBits();
+ unsigned DestBits = VT2.getVectorElementType().getSizeInBits();
+ if (SrcBits == DestBits * 2) {
+ return true;
+ }
+ }
+ return TargetLowering::isTruncateFree(Val, VT2);
+}
+
bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
// Zexts are free if they can be combined with a load.
// Don't advertise i32->i64 zextload as being free for RV64. It interacts
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 7d8bceb5cb341..2642a188820e1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
bool isLegalAddImmediate(int64_t Imm) const override;
bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
+ bool isTruncateFree(SDValue Val, EVT VT2) const override;
bool isZExtFree(SDValue Val, EVT VT2) const override;
bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
bool signExtendConstant(const ConstantInt *CI) const override;
diff --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
new file mode 100644
index 0000000000000..cb41e22381d19
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -0,0 +1,31 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
+; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s
+
+define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
+; RV32I-LABEL: PR94265:
+; RV32I: # %bb.0:
+; RV32I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32I-NEXT: vsra.vi v10, v8, 31
+; RV32I-NEXT: vsrl.vi v10, v10, 26
+; RV32I-NEXT: vadd.vv v8, v8, v10
+; RV32I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; RV32I-NEXT: vnsrl.wi v10, v8, 6
+; RV32I-NEXT: vsll.vi v8, v10, 10
+; RV32I-NEXT: ret
+;
+; RV64I-LABEL: PR94265:
+; RV64I: # %bb.0:
+; RV64I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV64I-NEXT: vsra.vi v10, v8, 31
+; RV64I-NEXT: vsrl.vi v10, v10, 26
+; RV64I-NEXT: vadd.vv v8, v8, v10
+; RV64I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; RV64I-NEXT: vnsrl.wi v10, v8, 6
+; RV64I-NEXT: vsll.vi v8, v10, 10
+; RV64I-NEXT: ret
+ %t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
+ %t2 = trunc <8 x i32> %t1 to <8 x i16>
+ %t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+ ret <8 x i16> %t3
+}
More information about the llvm-commits mailing list