[llvm] [RISCV] Generalize existing SRA combine to fix #101040. (PR #101610)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 1 21:00:26 PDT 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/101610

>From a6c16528c26513b4b9dd479c044f8a64d5a5517a Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 1 Aug 2024 20:38:36 -0700
Subject: [PATCH 1/3] [RISCV] Add tests for #101040. NFC

---
 llvm/test/CodeGen/RISCV/rv32zbb.ll | 38 ++++++++++++++++++++++++++++++
 llvm/test/CodeGen/RISCV/rv64zbb.ll | 38 ++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/rv32zbb.ll b/llvm/test/CodeGen/RISCV/rv32zbb.ll
index cb9fc6c16333e..7f96ba2a761a7 100644
--- a/llvm/test/CodeGen/RISCV/rv32zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv32zbb.ll
@@ -1417,3 +1417,41 @@ define i64 @orc_b_i64(i64 %a) {
   %2 = mul nuw i64 %1, 255
   ret i64 %2
 }
+
+define i32 @srai_slli(i16 signext %0) {
+; RV32I-LABEL: srai_slli:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    slli a0, a0, 25
+; RV32I-NEXT:    srai a0, a0, 31
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: srai_slli:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    slli a0, a0, 9
+; RV32ZBB-NEXT:    slli a0, a0, 16
+; RV32ZBB-NEXT:    srai a0, a0, 31
+; RV32ZBB-NEXT:    ret
+  %2 = shl i16 %0, 9
+  %sext = ashr i16 %2, 15
+  %3 = sext i16 %sext to i32
+  ret i32 %3
+}
+
+define i32 @srai_slli2(i16 signext %0) {
+; RV32I-LABEL: srai_slli2:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    slli a0, a0, 25
+; RV32I-NEXT:    srai a0, a0, 30
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: srai_slli2:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    slli a0, a0, 9
+; RV32ZBB-NEXT:    slli a0, a0, 16
+; RV32ZBB-NEXT:    srai a0, a0, 30
+; RV32ZBB-NEXT:    ret
+  %2 = shl i16 %0, 9
+  %sext = ashr i16 %2, 14
+  %3 = sext i16 %sext to i32
+  ret i32 %3
+}
diff --git a/llvm/test/CodeGen/RISCV/rv64zbb.ll b/llvm/test/CodeGen/RISCV/rv64zbb.ll
index 6c354cc1b446b..168f7ecd0cdc7 100644
--- a/llvm/test/CodeGen/RISCV/rv64zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zbb.ll
@@ -1560,3 +1560,41 @@ define i64 @orc_b_i64(i64 %a) {
   %2 = mul nuw i64 %1, 255
   ret i64 %2
 }
+
+define i64 @srai_slli(i16 signext %0) {
+; RV64I-LABEL: srai_slli:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a0, a0, 57
+; RV64I-NEXT:    srai a0, a0, 63
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: srai_slli:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    slli a0, a0, 9
+; RV64ZBB-NEXT:    slli a0, a0, 48
+; RV64ZBB-NEXT:    srai a0, a0, 63
+; RV64ZBB-NEXT:    ret
+  %2 = shl i16 %0, 9
+  %sext = ashr i16 %2, 15
+  %3 = sext i16 %sext to i64
+  ret i64 %3
+}
+
+define i64 @srai_slli2(i16 signext %0) {
+; RV64I-LABEL: srai_slli2:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a0, a0, 57
+; RV64I-NEXT:    srai a0, a0, 62
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: srai_slli2:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    slli a0, a0, 9
+; RV64ZBB-NEXT:    slli a0, a0, 48
+; RV64ZBB-NEXT:    srai a0, a0, 62
+; RV64ZBB-NEXT:    ret
+  %2 = shl i16 %0, 9
+  %sext = ashr i16 %2, 14
+  %3 = sext i16 %sext to i64
+  ret i64 %3
+}

>From 08019776fa3423208ee2ab4145039de32a3018df Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 1 Aug 2024 20:43:01 -0700
Subject: [PATCH 2/3] [RISCV] Generalize existing SRA combine to fix #101040.

We already had a DAG combine for (sra (sext_inreg (shl X, C1), i32), C2)
-> (sra (shl X, C1+32), C2+32) that we used for RV64. This patch
generalizes it to sext_inreg of any width, for both RV32 and RV64.

Fixes #101040.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 48 +++++++++++----------
 llvm/test/CodeGen/RISCV/rv32zbb.ll          | 34 +++++----------
 llvm/test/CodeGen/RISCV/rv64zbb.ll          |  6 +--
 3 files changed, 38 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68b614d1d3fdc..2efce6370d530 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1468,8 +1468,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
                        ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
                        ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
-  if (Subtarget.is64Bit())
-    setTargetDAGCombine(ISD::SRA);
+  setTargetDAGCombine(ISD::SRA);
 
   if (Subtarget.hasStdExtFOrZfinx())
     setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM, ISD::FMUL});
@@ -15465,37 +15464,42 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
 
-  if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
+  EVT VT = N->getValueType(0);
+
+  if (VT != Subtarget.getXLenVT())
     return SDValue();
 
   if (!isa<ConstantSDNode>(N->getOperand(1)))
     return SDValue();
   uint64_t ShAmt = N->getConstantOperandVal(1);
-  if (ShAmt > 32)
-    return SDValue();
 
   SDValue N0 = N->getOperand(0);
 
-  // Combine (sra (sext_inreg (shl X, C1), i32), C2) ->
-  // (sra (shl X, C1+32), C2+32) so it gets selected as SLLI+SRAI instead of
-  // SLLIW+SRAIW. SLLI+SRAI have compressed forms.
-  if (ShAmt < 32 &&
-      N0.getOpcode() == ISD::SIGN_EXTEND_INREG && N0.hasOneUse() &&
-      cast<VTSDNode>(N0.getOperand(1))->getVT() == MVT::i32 &&
-      N0.getOperand(0).getOpcode() == ISD::SHL && N0.getOperand(0).hasOneUse() &&
-      isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) {
-    uint64_t LShAmt = N0.getOperand(0).getConstantOperandVal(1);
-    if (LShAmt < 32) {
-      SDLoc ShlDL(N0.getOperand(0));
-      SDValue Shl = DAG.getNode(ISD::SHL, ShlDL, MVT::i64,
-                                N0.getOperand(0).getOperand(0),
-                                DAG.getConstant(LShAmt + 32, ShlDL, MVT::i64));
-      SDLoc DL(N);
-      return DAG.getNode(ISD::SRA, DL, MVT::i64, Shl,
-                         DAG.getConstant(ShAmt + 32, DL, MVT::i64));
+  // Combine (sra (sext_inreg (shl X, C1), iX), C2) ->
+  // (sra (shl X, C1+(XLen-iX)), C2+(XLen-iX)) so it gets selected as SLLI+SRAI.
+  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG && N0.hasOneUse()) {
+    unsigned ExtSize =
+        cast<VTSDNode>(N0.getOperand(1))->getVT().getSizeInBits();
+    if (ShAmt < ExtSize &&
+        N0.getOperand(0).getOpcode() == ISD::SHL && N0.getOperand(0).hasOneUse() &&
+        isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) {
+      uint64_t LShAmt = N0.getOperand(0).getConstantOperandVal(1);
+      if (LShAmt < ExtSize) {
+        unsigned Size = VT.getSizeInBits();
+        SDLoc ShlDL(N0.getOperand(0));
+        SDValue Shl = DAG.getNode(ISD::SHL, ShlDL, VT,
+                                  N0.getOperand(0).getOperand(0),
+                                  DAG.getConstant(LShAmt + (Size - ExtSize), ShlDL, VT));
+        SDLoc DL(N);
+        return DAG.getNode(ISD::SRA, DL, VT, Shl,
+                           DAG.getConstant(ShAmt + (Size - ExtSize), DL, VT));
+      }
     }
   }
 
+  if (ShAmt > 32 || VT != MVT::i64)
+    return SDValue();
+
   // Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
   // FIXME: Should this be a generic combine? There's a similar combine on X86.
   //
diff --git a/llvm/test/CodeGen/RISCV/rv32zbb.ll b/llvm/test/CodeGen/RISCV/rv32zbb.ll
index 7f96ba2a761a7..fa320f53cec6c 100644
--- a/llvm/test/CodeGen/RISCV/rv32zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv32zbb.ll
@@ -1419,18 +1419,11 @@ define i64 @orc_b_i64(i64 %a) {
 }
 
 define i32 @srai_slli(i16 signext %0) {
-; RV32I-LABEL: srai_slli:
-; RV32I:       # %bb.0:
-; RV32I-NEXT:    slli a0, a0, 25
-; RV32I-NEXT:    srai a0, a0, 31
-; RV32I-NEXT:    ret
-;
-; RV32ZBB-LABEL: srai_slli:
-; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    slli a0, a0, 9
-; RV32ZBB-NEXT:    slli a0, a0, 16
-; RV32ZBB-NEXT:    srai a0, a0, 31
-; RV32ZBB-NEXT:    ret
+; CHECK-LABEL: srai_slli:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 25
+; CHECK-NEXT:    srai a0, a0, 31
+; CHECK-NEXT:    ret
   %2 = shl i16 %0, 9
   %sext = ashr i16 %2, 15
   %3 = sext i16 %sext to i32
@@ -1438,18 +1431,11 @@ define i32 @srai_slli(i16 signext %0) {
 }
 
 define i32 @srai_slli2(i16 signext %0) {
-; RV32I-LABEL: srai_slli2:
-; RV32I:       # %bb.0:
-; RV32I-NEXT:    slli a0, a0, 25
-; RV32I-NEXT:    srai a0, a0, 30
-; RV32I-NEXT:    ret
-;
-; RV32ZBB-LABEL: srai_slli2:
-; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    slli a0, a0, 9
-; RV32ZBB-NEXT:    slli a0, a0, 16
-; RV32ZBB-NEXT:    srai a0, a0, 30
-; RV32ZBB-NEXT:    ret
+; CHECK-LABEL: srai_slli2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 25
+; CHECK-NEXT:    srai a0, a0, 30
+; CHECK-NEXT:    ret
   %2 = shl i16 %0, 9
   %sext = ashr i16 %2, 14
   %3 = sext i16 %sext to i32
diff --git a/llvm/test/CodeGen/RISCV/rv64zbb.ll b/llvm/test/CodeGen/RISCV/rv64zbb.ll
index 168f7ecd0cdc7..3ee9300dcc01e 100644
--- a/llvm/test/CodeGen/RISCV/rv64zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zbb.ll
@@ -1570,8 +1570,7 @@ define i64 @srai_slli(i16 signext %0) {
 ;
 ; RV64ZBB-LABEL: srai_slli:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    slli a0, a0, 9
-; RV64ZBB-NEXT:    slli a0, a0, 48
+; RV64ZBB-NEXT:    slli a0, a0, 57
 ; RV64ZBB-NEXT:    srai a0, a0, 63
 ; RV64ZBB-NEXT:    ret
   %2 = shl i16 %0, 9
@@ -1589,8 +1588,7 @@ define i64 @srai_slli2(i16 signext %0) {
 ;
 ; RV64ZBB-LABEL: srai_slli2:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    slli a0, a0, 9
-; RV64ZBB-NEXT:    slli a0, a0, 48
+; RV64ZBB-NEXT:    slli a0, a0, 57
 ; RV64ZBB-NEXT:    srai a0, a0, 62
 ; RV64ZBB-NEXT:    ret
   %2 = shl i16 %0, 9

>From c95b40d1b7c2015fbee17b51cb348e06ae7fb99d Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 1 Aug 2024 21:00:10 -0700
Subject: [PATCH 3/3] fixup! format

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2efce6370d530..6056e9ce84294 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15480,16 +15480,16 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG && N0.hasOneUse()) {
     unsigned ExtSize =
         cast<VTSDNode>(N0.getOperand(1))->getVT().getSizeInBits();
-    if (ShAmt < ExtSize &&
-        N0.getOperand(0).getOpcode() == ISD::SHL && N0.getOperand(0).hasOneUse() &&
+    if (ShAmt < ExtSize && N0.getOperand(0).getOpcode() == ISD::SHL &&
+        N0.getOperand(0).hasOneUse() &&
         isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) {
       uint64_t LShAmt = N0.getOperand(0).getConstantOperandVal(1);
       if (LShAmt < ExtSize) {
         unsigned Size = VT.getSizeInBits();
         SDLoc ShlDL(N0.getOperand(0));
-        SDValue Shl = DAG.getNode(ISD::SHL, ShlDL, VT,
-                                  N0.getOperand(0).getOperand(0),
-                                  DAG.getConstant(LShAmt + (Size - ExtSize), ShlDL, VT));
+        SDValue Shl =
+            DAG.getNode(ISD::SHL, ShlDL, VT, N0.getOperand(0).getOperand(0),
+                        DAG.getConstant(LShAmt + (Size - ExtSize), ShlDL, VT));
         SDLoc DL(N);
         return DAG.getNode(ISD::SRA, DL, VT, Shl,
                            DAG.getConstant(ShAmt + (Size - ExtSize), DL, VT));



More information about the llvm-commits mailing list