[llvm] c8dc21d - [SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (#95563)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 14 04:09:41 PDT 2024


Author: Froster
Date: 2024-07-14T12:09:37+01:00
New Revision: c8dc21d77fc82d9360953100aa328a13185f8ba0

URL: https://github.com/llvm/llvm-project/commit/c8dc21d77fc82d9360953100aa328a13185f8ba0
DIFF: https://github.com/llvm/llvm-project/commit/c8dc21d77fc82d9360953100aa328a13185f8ba0.diff

LOG: [SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (#95563)

Added a RISC-V override of the `isTruncateFree(SDValue, EVT)` hook to fix the broken vnsrl pattern described in issue #94265.

Fixes #94265

Added: 
    llvm/test/CodeGen/RISCV/pr94265.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 690a86bd4606c..92e18a4b630e9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2586,6 +2586,17 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
+        if (isTruncateFree(Src, VT) &&
+            !isTruncateFree(Src.getValueType(), VT)) {
+          // If the truncate is only free as trunc(srl), do not turn it into
+          // srl(trunc). We first check that the truncate is free given Src's
+          // opcode (srl), then check that the truncate is not free purely by
+          // type (e.g. by referencing a sub-register). In testing, if the
+          // truncate is free in both forms, srl(trunc) performs better; if
+          // only trunc(srl)'s truncate is free, trunc(srl) is better.
+          break;
+        }
+
         std::optional<uint64_t> ShAmtC =
             TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
         if (!ShAmtC || *ShAmtC >= BitWidth)
@@ -2596,7 +2607,6 @@ bool TargetLowering::SimplifyDemandedBits(
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
         HighBits.lshrInPlace(ShVal);
         HighBits = HighBits.trunc(BitWidth);
-
         if (!(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed.  Add a truncate of the
           // shift input, then shift it.

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b8ba25df9910b..caa4ebacc41da 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1894,6 +1894,21 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
   return (SrcBits == 64 && DestBits == 32);
 }
 
+bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
+  EVT SrcVT = Val.getValueType();
+  // Vector trunc of SRL/SRA to half the element width is free (vnsrl/vnsra).
+  if (Subtarget.hasStdExtV() &&
+      (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) &&
+      SrcVT.isVector() && VT2.isVector()) {
+    unsigned SrcBits = SrcVT.getVectorElementType().getSizeInBits();
+    unsigned DestBits = VT2.getVectorElementType().getSizeInBits();
+    if (SrcBits == DestBits * 2) {
+      return true;
+    }
+  }
+  return TargetLowering::isTruncateFree(Val, VT2);
+}
+
 bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
   // Zexts are free if they can be combined with a load.
   // Don't advertise i32->i64 zextload as being free for RV64. It interacts

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 7d8bceb5cb341..2642a188820e1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
   bool isLegalAddImmediate(int64_t Imm) const override;
   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
   bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
+  bool isTruncateFree(SDValue Val, EVT VT2) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
   bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
   bool signExtendConstant(const ConstantInt *CI) const override;

diff  --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
new file mode 100644
index 0000000000000..cb41e22381d19
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -0,0 +1,31 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
+; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s
+
+define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
+; RV32I-LABEL: PR94265:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV32I-NEXT:    vsra.vi v10, v8, 31
+; RV32I-NEXT:    vsrl.vi v10, v10, 26
+; RV32I-NEXT:    vadd.vv v8, v8, v10
+; RV32I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; RV32I-NEXT:    vnsrl.wi v10, v8, 6
+; RV32I-NEXT:    vsll.vi v8, v10, 10
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: PR94265:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV64I-NEXT:    vsra.vi v10, v8, 31
+; RV64I-NEXT:    vsrl.vi v10, v10, 26
+; RV64I-NEXT:    vadd.vv v8, v8, v10
+; RV64I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; RV64I-NEXT:    vnsrl.wi v10, v8, 6
+; RV64I-NEXT:    vsll.vi v8, v10, 10
+; RV64I-NEXT:    ret
+  %t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
+  %t2 = trunc <8 x i32> %t1 to <8 x i16>
+  %t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+  ret <8 x i16> %t3
+}


        


More information about the llvm-commits mailing list