[llvm] [SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (PR #95563)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 29 12:31:33 PDT 2024


https://github.com/Fros1er updated https://github.com/llvm/llvm-project/pull/95563

>From a2018effea48b9526ab17feb58f30319a10894d8 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Fri, 14 Jun 2024 22:28:38 +0800
Subject: [PATCH 1/5] [SelectionDAG][RISCV] Add pre-commit tests.

---
 llvm/test/CodeGen/RISCV/pr94265.ll | 35 ++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/pr94265.ll

diff --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
new file mode 100644
index 0000000000000..b1dff117eb17c
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
+; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s
+
+define <8 x i16> @PR94265(<8 x i32> %a0) {
+; RV32I-LABEL: PR94265:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV32I-NEXT:    vsra.vi v10, v8, 31
+; RV32I-NEXT:    vsrl.vi v10, v10, 26
+; RV32I-NEXT:    vadd.vv v8, v8, v10
+; RV32I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; RV32I-NEXT:    vnsrl.wi v10, v8, 0
+; RV32I-NEXT:    vsll.vi v8, v10, 4
+; RV32I-NEXT:    li a0, -1024
+; RV32I-NEXT:    vand.vx v8, v8, a0
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: PR94265:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV64I-NEXT:    vsra.vi v10, v8, 31
+; RV64I-NEXT:    vsrl.vi v10, v10, 26
+; RV64I-NEXT:    vadd.vv v8, v8, v10
+; RV64I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; RV64I-NEXT:    vnsrl.wi v10, v8, 0
+; RV64I-NEXT:    vsll.vi v8, v10, 4
+; RV64I-NEXT:    li a0, -1024
+; RV64I-NEXT:    vand.vx v8, v8, a0
+; RV64I-NEXT:    ret
+  %t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
+  %t2 = trunc <8 x i32> %t1 to <8 x i16>
+  %t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+  ret <8 x i16> %t3
+}

>From 2c04c2327e14a5301b01c6eb6fd0f9aac71e2a05 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Fri, 14 Jun 2024 22:40:06 +0800
Subject: [PATCH 2/5] [SelectionDAG][RISCV] Add isTypeDesirableForOp with
 NewVT+OldVT, fix issue#94265

---
 llvm/include/llvm/CodeGen/TargetLowering.h       | 14 ++++++++++++++
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp |  4 +++-
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp      |  8 ++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h        |  2 ++
 llvm/test/CodeGen/RISCV/pr94265.ll               | 12 ++++--------
 5 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 3074ece787a08..f0e20e4372b8d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4339,6 +4339,20 @@ class TargetLowering : public TargetLoweringBase {
     return isTypeLegal(VT);
   }
 
+  /// Same as isTypeDesirableForOp(unsigned Opc, EVT VT), but also check if
+  /// the target is 'desirable' to truncate or extend OldVT to NewVT only using
+  /// the given node type, without the need of explicit trunc or ext. e.g. On
+  /// RISC-V Vector extension, vnsrl.wi can directly convert <n x i32> to <n x
+  /// i16> when shifting, with no extra trunc operations needed.
+  virtual bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const {
+    // Fallback to isTypeDesirableForOp(unsigned Opc, EVT VT).
+    if (NewVT == OldVT) {
+      return isTypeDesirableForOp(Opc, NewVT);
+    }
+    // Most of instructions are not desirable, so return false by default.
+    return false;
+  }
+
   /// Return true if it is profitable for dag combiner to transform a floating
   /// point op of specified opcode to a equivalent op of an integer
   /// type. e.g. f32 load -> i32 load can be profitable on ARM.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 623d2e0a047ef..373aeac5e7317 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2597,7 +2597,9 @@ bool TargetLowering::SimplifyDemandedBits(
         HighBits.lshrInPlace(ShVal);
         HighBits = HighBits.trunc(BitWidth);
 
-        if (!(HighBits & DemandedBits)) {
+        if (!isTypeDesirableForOp(ISD::SRL, Op.getValueType(),
+                                  Src.getValueType()) &&
+            !(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed.  Add a truncate of the
           // shift input, then shift it.
           SDValue NewShAmt =
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b1b27f03252e0..694e0b0dff1a3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17462,6 +17462,14 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
   return true;
 }
 
+bool RISCVTargetLowering::isTypeDesirableForOp(unsigned Opc, EVT NewVT,
+                                               EVT OldVT) const {
+  if (Subtarget.hasStdExtV() && NewVT.isVector() && OldVT.isVector()) {
+    return true;
+  }
+  return TargetLowering::isTypeDesirableForOp(Opc, NewVT, OldVT);
+}
+
 bool RISCVTargetLowering::targetShrinkDemandedConstant(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
     TargetLoweringOpt &TLO) const {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 3b8eb3c88901a..353836783ccfb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -708,6 +708,8 @@ class RISCVTargetLowering : public TargetLowering {
   bool isDesirableToCommuteWithShift(const SDNode *N,
                                      CombineLevel Level) const override;
 
+  bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const override;
+
   /// If a physical register, this returns the register that receives the
   /// exception address on entry to an EH pad.
   Register
diff --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
index b1dff117eb17c..cb41e22381d19 100644
--- a/llvm/test/CodeGen/RISCV/pr94265.ll
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -10,10 +10,8 @@ define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
 ; RV32I-NEXT:    vsrl.vi v10, v10, 26
 ; RV32I-NEXT:    vadd.vv v8, v8, v10
 ; RV32I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
-; RV32I-NEXT:    vnsrl.wi v10, v8, 0
-; RV32I-NEXT:    vsll.vi v8, v10, 4
-; RV32I-NEXT:    li a0, -1024
-; RV32I-NEXT:    vand.vx v8, v8, a0
+; RV32I-NEXT:    vnsrl.wi v10, v8, 6
+; RV32I-NEXT:    vsll.vi v8, v10, 10
 ; RV32I-NEXT:    ret
 ;
 ; RV64I-LABEL: PR94265:
@@ -23,10 +21,8 @@ define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
 ; RV64I-NEXT:    vsrl.vi v10, v10, 26
 ; RV64I-NEXT:    vadd.vv v8, v8, v10
 ; RV64I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
-; RV64I-NEXT:    vnsrl.wi v10, v8, 0
-; RV64I-NEXT:    vsll.vi v8, v10, 4
-; RV64I-NEXT:    li a0, -1024
-; RV64I-NEXT:    vand.vx v8, v8, a0
+; RV64I-NEXT:    vnsrl.wi v10, v8, 6
+; RV64I-NEXT:    vsll.vi v8, v10, 10
 ; RV64I-NEXT:    ret
   %t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
   %t2 = trunc <8 x i32> %t1 to <8 x i16>

>From f6ab73cb4fb9f0fe40637f616e0f585cfa3ae534 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Fri, 28 Jun 2024 21:47:26 +0800
Subject: [PATCH 3/5] rename new func to isTypeDesirableForOpWithCast

---
 llvm/include/llvm/CodeGen/TargetLowering.h       | 3 ++-
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 4 ++--
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp      | 6 +++---
 llvm/lib/Target/RISCV/RISCVISelLowering.h        | 3 ++-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index f0e20e4372b8d..c94c0b1f9a4e7 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4344,7 +4344,8 @@ class TargetLowering : public TargetLoweringBase {
   /// the given node type, without the need of explicit trunc or ext. e.g. On
   /// RISC-V Vector extension, vnsrl.wi can directly convert <n x i32> to <n x
   /// i16> when shifting, with no extra trunc operations needed.
-  virtual bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const {
+  virtual bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
+                                            EVT OldVT) const {
     // Fallback to isTypeDesirableForOp(unsigned Opc, EVT VT).
     if (NewVT == OldVT) {
       return isTypeDesirableForOp(Opc, NewVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 373aeac5e7317..1a8748fa3d131 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2597,8 +2597,8 @@ bool TargetLowering::SimplifyDemandedBits(
         HighBits.lshrInPlace(ShVal);
         HighBits = HighBits.trunc(BitWidth);
 
-        if (!isTypeDesirableForOp(ISD::SRL, Op.getValueType(),
-                                  Src.getValueType()) &&
+        if (!isTypeDesirableForOpWithCast(ISD::SRL, Op.getValueType(),
+                                          Src.getValueType()) &&
             !(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed.  Add a truncate of the
           // shift input, then shift it.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 694e0b0dff1a3..b1a3684835343 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17462,12 +17462,12 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
   return true;
 }
 
-bool RISCVTargetLowering::isTypeDesirableForOp(unsigned Opc, EVT NewVT,
-                                               EVT OldVT) const {
+bool RISCVTargetLowering::isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
+                                                       EVT OldVT) const {
   if (Subtarget.hasStdExtV() && NewVT.isVector() && OldVT.isVector()) {
     return true;
   }
-  return TargetLowering::isTypeDesirableForOp(Opc, NewVT, OldVT);
+  return TargetLowering::isTypeDesirableForOpWithCast(Opc, NewVT, OldVT);
 }
 
 bool RISCVTargetLowering::targetShrinkDemandedConstant(
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 353836783ccfb..b79f8ca67bcd5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -708,7 +708,8 @@ class RISCVTargetLowering : public TargetLowering {
   bool isDesirableToCommuteWithShift(const SDNode *N,
                                      CombineLevel Level) const override;
 
-  bool isTypeDesirableForOp(unsigned Opc, EVT NewVT, EVT OldVT) const override;
+  bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
+                                    EVT OldVT) const override;
 
   /// If a physical register, this returns the register that receives the
   /// exception address on entry to an EH pad.

>From 6286d17512d4f95fdf4c616d45e60c11a60b3d39 Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Sun, 30 Jun 2024 03:30:24 +0800
Subject: [PATCH 4/5] remove new func, use overridden isTruncateFree instead

---
 llvm/include/llvm/CodeGen/TargetLowering.h       | 15 ---------------
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 10 +++++++---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp      | 16 ++++++++--------
 llvm/lib/Target/RISCV/RISCVISelLowering.h        |  4 +---
 4 files changed, 16 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index c94c0b1f9a4e7..3074ece787a08 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4339,21 +4339,6 @@ class TargetLowering : public TargetLoweringBase {
     return isTypeLegal(VT);
   }
 
-  /// Same as isTypeDesirableForOp(unsigned Opc, EVT VT), but also check if
-  /// the target is 'desirable' to truncate or extend OldVT to NewVT only using
-  /// the given node type, without the need of explicit trunc or ext. e.g. On
-  /// RISC-V Vector extension, vnsrl.wi can directly convert <n x i32> to <n x
-  /// i16> when shifting, with no extra trunc operations needed.
-  virtual bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
-                                            EVT OldVT) const {
-    // Fallback to isTypeDesirableForOp(unsigned Opc, EVT VT).
-    if (NewVT == OldVT) {
-      return isTypeDesirableForOp(Opc, NewVT);
-    }
-    // Most of instructions are not desirable, so return false by default.
-    return false;
-  }
-
   /// Return true if it is profitable for dag combiner to transform a floating
   /// point op of specified opcode to a equivalent op of an integer
   /// type. e.g. f32 load -> i32 load can be profitable on ARM.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 1a8748fa3d131..60cad8f5b30e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2597,8 +2597,12 @@ bool TargetLowering::SimplifyDemandedBits(
         HighBits.lshrInPlace(ShVal);
         HighBits = HighBits.trunc(BitWidth);
 
-        if (!isTypeDesirableForOpWithCast(ISD::SRL, Op.getValueType(),
-                                          Src.getValueType()) &&
-            !(HighBits & DemandedBits)) {
+        // Keep trunc(srl) when the target folds the truncate into the shift
+        // itself (e.g. RVV vnsrl/vnsra) but a plain truncate of the source
+        // type is not free; otherwise srl(trunc) is the better form.
+        bool TruncFoldsIntoShift =
+            isTruncateFree(Src, Op.getValueType()) &&
+            !isTruncateFree(Src.getValueType(), Op.getValueType());
+        if (!TruncFoldsIntoShift && !(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed.  Add a truncate of the
           // shift input, then shift it.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b1a3684835343..460ee29abd09f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1884,6 +1884,14 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
   return (SrcBits == 64 && DestBits == 32);
 }
 
+bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
+  // free truncate from vnsrl and vnsra
+  if (Subtarget.hasStdExtV() && (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) && Val.getValueType().isVector() && VT2.isVector()) {
+    return true;
+  }
+  return TargetLowering::isTruncateFree(Val, VT2);
+}
+
 bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
   // Zexts are free if they can be combined with a load.
   // Don't advertise i32->i64 zextload as being free for RV64. It interacts
@@ -17462,14 +17470,6 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
   return true;
 }
 
-bool RISCVTargetLowering::isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
-                                                       EVT OldVT) const {
-  if (Subtarget.hasStdExtV() && NewVT.isVector() && OldVT.isVector()) {
-    return true;
-  }
-  return TargetLowering::isTypeDesirableForOpWithCast(Opc, NewVT, OldVT);
-}
-
 bool RISCVTargetLowering::targetShrinkDemandedConstant(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
     TargetLoweringOpt &TLO) const {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index b79f8ca67bcd5..d66374ec5b171 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
   bool isLegalAddImmediate(int64_t Imm) const override;
   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
   bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
+  bool isTruncateFree(SDValue Val, EVT VT2) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
   bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
   bool signExtendConstant(const ConstantInt *CI) const override;
@@ -708,9 +709,6 @@ class RISCVTargetLowering : public TargetLowering {
   bool isDesirableToCommuteWithShift(const SDNode *N,
                                      CombineLevel Level) const override;
 
-  bool isTypeDesirableForOpWithCast(unsigned Opc, EVT NewVT,
-                                    EVT OldVT) const override;
-
   /// If a physical register, this returns the register that receives the
   /// exception address on entry to an EH pad.
   Register

>From 261506d939a775c800c4bb3d7eab945336d24acf Mon Sep 17 00:00:00 2001
From: Fros1er <34234343+Fros1er at users.noreply.github.com>
Date: Sun, 30 Jun 2024 03:31:17 +0800
Subject: [PATCH 5/5] format; only treat 2*SEW->SEW vector shift truncates
 as free

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 460ee29abd09f..77eef4d0501b5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1886,7 +1886,13 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
 
 bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
-  // free truncate from vnsrl and vnsra
-  if (Subtarget.hasStdExtV() && (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) && Val.getValueType().isVector() && VT2.isVector()) {
+  // A truncate of a legal vector right shift is selected as one narrowing
+  // shift (vnsrl/vnsra), which can only halve the element width.
+  EVT SrcVT = Val.getValueType();
+  if (Subtarget.hasVInstructions() && SrcVT.isVector() && VT2.isVector() &&
+      isTypeLegal(SrcVT) &&
+      (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) &&
+      SrcVT.getScalarSizeInBits() == 2 * VT2.getScalarSizeInBits()) {
     return true;
   }
   return TargetLowering::isTruncateFree(Val, VT2);



More information about the llvm-commits mailing list