[llvm] 51d6729 - [RISCV] Fold (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1)), C)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 30 09:02:21 PDT 2022


Author: Craig Topper
Date: 2022-06-30T09:01:24-07:00
New Revision: 51d672946efdfacc06948cd46b51109b07ac12e5

URL: https://github.com/llvm/llvm-project/commit/51d672946efdfacc06948cd46b51109b07ac12e5
DIFF: https://github.com/llvm/llvm-project/commit/51d672946efdfacc06948cd46b51109b07ac12e5.diff

LOG: [RISCV] Fold (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1)), C)

Similarly for a subtract with a constant left-hand side.

(sra (add (shl X, 32), C1<<32), 32) is the canonical IR that InstCombine
produces for (sext (add (trunc X to i32), C1) to i64).
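
A minimal IR sketch of that canonical form (value names illustrative;
compare test7 in the diff below, which uses C1 = 1, so C1<<32 =
4294967296):

    define i64 @sketch(i64 %x) {
      %s = shl i64 %x, 32
      %a = add i64 %s, 4294967296   ; C1 << 32
      %r = ashr i64 %a, 32
      ret i64 %r
    }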

For RISCV, we should lower this to addiw, which means turning it into
(sext_inreg (add X, C1)).
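
With this combine in place, that is a single instruction: test7 in the
updated test file collapses a five-instruction slli/li/slli/add/srai
sequence down to

    addiw a0, a1, 1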

There is an existing DAG combine to convert back to (sext (add (trunc X
to i32), C1) to i64), but it requires isTruncateFree to return true and
i32 to be a legal type, since it uses sign_extend and truncate nodes.
So that doesn't work for RISCV.
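
Roughly, that generic combine rebuilds the pattern in this shape (a
sketch of the resulting nodes written as IR, not the actual DAGCombiner
code; names are illustrative):

    define i64 @sketch(i64 %x) {
      %t = trunc i64 %x to i32
      %a = add i32 %t, 1        ; C1 >> 32
      %r = sext i32 %a to i64
      ret i64 %r
    }

and it cannot fire here because i32 is not a legal type on RV64.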

If the outer sra happens to be used by an shl by constant, the shl will
be folded into the sra and the sra's shift amount will change before we
can run our own DAG combine. This requires us to match the more general
pattern and restore the shl.
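
test9 in the updated test file shows this case: an shl by 2 (the i32
address scaling) has already been folded in, reducing the sra amount to
30, so the combine matches

    (sra (add (shl X, 32), C1), 30)

and restores the shl, visible as the trailing "slli a1, a1, 2" in the
new output.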

I had wanted to do this as a separate (add (shl X, 32), C1<<32) ->
(shl (add X, C1), 32) combine, but that hit an infinite loop for some
values of C1, presumably because the generic (shl (add X, C1), C2) ->
(add (shl X, C2), C1<<C2) fold kept reversing it.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D128869

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b676a2bc5e39d..ff645dea4e7af 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8532,6 +8532,10 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
 
 // Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
 // FIXME: Should this be a generic combine? There's a similar combine on X86.
+//
+// Also try these folds where an add or sub is in the middle.
+// (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1), i32), C)
+// (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub C1, X), i32), C)
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -8539,21 +8543,63 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
   if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
     return SDValue();
 
-  auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!C || C->getZExtValue() >= 32)
+  auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!ShAmtC || ShAmtC->getZExtValue() > 32)
     return SDValue();
 
   SDValue N0 = N->getOperand(0);
-  if (N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
-      !isa<ConstantSDNode>(N0.getOperand(1)) ||
-      N0.getConstantOperandVal(1) != 32)
+
+  SDValue Shl;
+  ConstantSDNode *AddC = nullptr;
+
+  // We might have an ADD or SUB between the SRA and SHL.
+  bool IsAdd = N0.getOpcode() == ISD::ADD;
+  if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
+    if (!N0.hasOneUse())
+      return SDValue();
+    // Other operand needs to be a constant we can modify.
+    AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
+    if (!AddC)
+      return SDValue();
+
+    // AddC needs to have at least 32 trailing zeros.
+    if (AddC->getAPIntValue().countTrailingZeros() < 32)
+      return SDValue();
+
+    Shl = N0.getOperand(IsAdd ? 0 : 1);
+  } else {
+    // Not an ADD or SUB.
+    Shl = N0;
+  }
+
+  // Look for a shift left by 32.
+  if (Shl.getOpcode() != ISD::SHL || !Shl.hasOneUse() ||
+      !isa<ConstantSDNode>(Shl.getOperand(1)) ||
+      Shl.getConstantOperandVal(1) != 32)
     return SDValue();
 
   SDLoc DL(N);
-  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64,
-                             N0.getOperand(0), DAG.getValueType(MVT::i32));
-  return DAG.getNode(ISD::SHL, DL, MVT::i64, SExt,
-                     DAG.getConstant(32 - C->getZExtValue(), DL, MVT::i64));
+  SDValue In = Shl.getOperand(0);
+
+  // If we looked through an ADD or SUB, we need to rebuild it with the shifted
+  // constant.
+  if (AddC) {
+    SDValue ShiftedAddC =
+        DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
+    if (IsAdd)
+      In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
+    else
+      In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
+  }
+
+  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
+                             DAG.getValueType(MVT::i32));
+  if (ShAmtC->getZExtValue() == 32)
+    return SExt;
+
+  return DAG.getNode(
+      ISD::SHL, DL, MVT::i64, SExt,
+      DAG.getConstant(32 - ShAmtC->getZExtValue(), DL, MVT::i64));
 }
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,

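A quick sanity check of the constant rebuild above (a worked example,
not part of the patch), using test7's values C1 = 1<<32 and an sra
amount of 32:

    (X << 32) + (1 << 32) = (X + 1) << 32      (always exact mod 2^64)
    ((X + 1) << 32) ashr 32 = sext_i32(X + 1)

which is exactly what addiw computes. Requiring at least 32 trailing
zeros in C1 is what makes the first reassociation hold, and when the
sra amount is 32 - C with C > 0 it also guarantees no low bits of C1
would have survived the original, smaller shift.
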
diff --git a/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll b/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
index f77fc58e403b2..ab6e1127b9070 100644
--- a/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
+++ b/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
@@ -84,11 +84,7 @@ define i64 @test6(i32 signext %a, i32 signext %b) nounwind {
 define i64 @test7(i32* %0, i64 %1) {
 ; RV64I-LABEL: test7:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a0, a1, 32
-; RV64I-NEXT:    li a1, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    add a0, a0, a1
-; RV64I-NEXT:    srai a0, a0, 32
+; RV64I-NEXT:    addiw a0, a1, 1
 ; RV64I-NEXT:    ret
   %3 = shl i64 %1, 32
   %4 = add i64 %3, 4294967296
@@ -102,11 +98,8 @@ define i64 @test7(i32* %0, i64 %1) {
 define i64 @test8(i32* %0, i64 %1) {
 ; RV64I-LABEL: test8:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a0, a1, 32
-; RV64I-NEXT:    li a1, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    sub a0, a1, a0
-; RV64I-NEXT:    srai a0, a0, 32
+; RV64I-NEXT:    li a0, 1
+; RV64I-NEXT:    subw a0, a0, a1
 ; RV64I-NEXT:    ret
   %3 = mul i64 %1, -4294967296
   %4 = add i64 %3, 4294967296
@@ -119,11 +112,10 @@ define i64 @test8(i32* %0, i64 %1) {
 define signext i32 @test9(i32* %0, i64 %1) {
 ; RV64I-LABEL: test9:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    lui a2, 4097
-; RV64I-NEXT:    slli a2, a2, 20
-; RV64I-NEXT:    add a1, a1, a2
-; RV64I-NEXT:    srai a1, a1, 30
+; RV64I-NEXT:    lui a2, 1
+; RV64I-NEXT:    addiw a2, a2, 1
+; RV64I-NEXT:    addw a1, a1, a2
+; RV64I-NEXT:    slli a1, a1, 2
 ; RV64I-NEXT:    add a0, a0, a1
 ; RV64I-NEXT:    lw a0, 0(a0)
 ; RV64I-NEXT:    ret
@@ -140,12 +132,10 @@ define signext i32 @test9(i32* %0, i64 %1) {
 define signext i32 @test10(i32* %0, i64 %1) {
 ; RV64I-LABEL: test10:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a1, a1, 32
 ; RV64I-NEXT:    lui a2, 30141
 ; RV64I-NEXT:    addiw a2, a2, -747
-; RV64I-NEXT:    slli a2, a2, 32
-; RV64I-NEXT:    sub a1, a2, a1
-; RV64I-NEXT:    srai a1, a1, 30
+; RV64I-NEXT:    subw a1, a2, a1
+; RV64I-NEXT:    slli a1, a1, 2
 ; RV64I-NEXT:    add a0, a0, a1
 ; RV64I-NEXT:    lw a0, 0(a0)
 ; RV64I-NEXT:    ret
@@ -160,11 +150,8 @@ define signext i32 @test10(i32* %0, i64 %1) {
 define i64 @test11(i32* %0, i64 %1) {
 ; RV64I-LABEL: test11:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a0, a1, 32
-; RV64I-NEXT:    li a1, -1
-; RV64I-NEXT:    slli a1, a1, 63
-; RV64I-NEXT:    sub a0, a1, a0
-; RV64I-NEXT:    srai a0, a0, 32
+; RV64I-NEXT:    lui a0, 524288
+; RV64I-NEXT:    subw a0, a0, a1
 ; RV64I-NEXT:    ret
   %3 = mul i64 %1, -4294967296
   %4 = add i64 %3, 9223372036854775808 ;0x8000'0000'0000'0000

