[llvm] 4830fa1 - [RISCV] Make sure we always call tryShrinkShlLogicImm for ISD::AND during isel.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 22 14:30:19 PDT 2022


Author: Craig Topper
Date: 2022-10-22T14:30:13-07:00
New Revision: 4830fa18aac6950d479a413c995c38fff56ac42c

URL: https://github.com/llvm/llvm-project/commit/4830fa18aac6950d479a413c995c38fff56ac42c
DIFF: https://github.com/llvm/llvm-project/commit/4830fa18aac6950d479a413c995c38fff56ac42c.diff

LOG: [RISCV] Make sure we always call tryShrinkShlLogicImm for ISD::AND during isel.

There was an early out that prevented us from calling this for
(and (sext_inreg (shl X, C1), i32), C2).

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
    llvm/test/CodeGen/RISCV/shift-and.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 1ff1ea138e06a..d62098d45717c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -858,178 +858,181 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     SDValue N0 = Node->getOperand(0);
 
     bool LeftShift = N0.getOpcode() == ISD::SHL;
-    if (!LeftShift && N0.getOpcode() != ISD::SRL)
-      break;
+    if (LeftShift || N0.getOpcode() == ISD::SRL) {
+      auto *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
+      if (!C)
+        break;
+      unsigned C2 = C->getZExtValue();
+      unsigned XLen = Subtarget->getXLen();
+      assert((C2 > 0 && C2 < XLen) && "Unexpected shift amount!");
 
-    auto *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-    if (!C)
-      break;
-    unsigned C2 = C->getZExtValue();
-    unsigned XLen = Subtarget->getXLen();
-    assert((C2 > 0 && C2 < XLen) && "Unexpected shift amount!");
+      uint64_t C1 = N1C->getZExtValue();
 
-    uint64_t C1 = N1C->getZExtValue();
+      // Keep track of whether this is a c.andi. If we can't use c.andi, the
+      // shift pair might offer more compression opportunities.
+      // TODO: We could check for C extension here, but we don't have many lit
+      // tests with the C extension enabled so not checking gets better
+      // coverage.
+      // TODO: What if ANDI is faster than shift?
+      bool IsCANDI = isInt<6>(N1C->getSExtValue());
 
-    // Keep track of whether this is a c.andi. If we can't use c.andi, the
-    // shift pair might offer more compression opportunities.
-    // TODO: We could check for C extension here, but we don't have many lit
-    // tests with the C extension enabled so not checking gets better coverage.
-    // TODO: What if ANDI faster than shift?
-    bool IsCANDI = isInt<6>(N1C->getSExtValue());
-
-    // Clear irrelevant bits in the mask.
-    if (LeftShift)
-      C1 &= maskTrailingZeros<uint64_t>(C2);
-    else
-      C1 &= maskTrailingOnes<uint64_t>(XLen - C2);
-
-    // Some transforms should only be done if the shift has a single use or
-    // the AND would become (srli (slli X, 32), 32)
-    bool OneUseOrZExtW = N0.hasOneUse() || C1 == UINT64_C(0xFFFFFFFF);
-
-    SDValue X = N0.getOperand(0);
-
-    // Turn (and (srl x, c2) c1) -> (srli (slli x, c3-c2), c3) if c1 is a mask
-    // with c3 leading zeros.
-    if (!LeftShift && isMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-      if (C2 < Leading) {
-        // If the number of leading zeros is C2+32 this can be SRLIW.
-        if (C2 + 32 == Leading) {
-          SDNode *SRLIW = CurDAG->getMachineNode(
-              RISCV::SRLIW, DL, VT, X, CurDAG->getTargetConstant(C2, DL, VT));
-          ReplaceNode(Node, SRLIW);
-          return;
-        }
+      // Clear irrelevant bits in the mask.
+      if (LeftShift)
+        C1 &= maskTrailingZeros<uint64_t>(C2);
+      else
+        C1 &= maskTrailingOnes<uint64_t>(XLen - C2);
+
+      // Some transforms should only be done if the shift has a single use or
+      // the AND would become (srli (slli X, 32), 32)
+      bool OneUseOrZExtW = N0.hasOneUse() || C1 == UINT64_C(0xFFFFFFFF);
+
+      SDValue X = N0.getOperand(0);
+
+      // Turn (and (srl x, c2) c1) -> (srli (slli x, c3-c2), c3) if c1 is a mask
+      // with c3 leading zeros.
+      if (!LeftShift && isMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+        if (C2 < Leading) {
+          // If the number of leading zeros is C2+32 this can be SRLIW.
+          if (C2 + 32 == Leading) {
+            SDNode *SRLIW = CurDAG->getMachineNode(
+                RISCV::SRLIW, DL, VT, X, CurDAG->getTargetConstant(C2, DL, VT));
+            ReplaceNode(Node, SRLIW);
+            return;
+          }
 
-        // (and (srl (sexti32 Y), c2), c1) -> (srliw (sraiw Y, 31), c3 - 32) if
-        // c1 is a mask with c3 leading zeros and c2 >= 32 and c3-c2==1.
-        //
-        // This pattern occurs when (i32 (srl (sra 31), c3 - 32)) is type
-        // legalized and goes through DAG combine.
-        if (C2 >= 32 && (Leading - C2) == 1 && N0.hasOneUse() &&
-            X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
-            cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32) {
-          SDNode *SRAIW =
-              CurDAG->getMachineNode(RISCV::SRAIW, DL, VT, X.getOperand(0),
-                                     CurDAG->getTargetConstant(31, DL, VT));
-          SDNode *SRLIW = CurDAG->getMachineNode(
-              RISCV::SRLIW, DL, VT, SDValue(SRAIW, 0),
-              CurDAG->getTargetConstant(Leading - 32, DL, VT));
-          ReplaceNode(Node, SRLIW);
-          return;
+          // (and (srl (sexti32 Y), c2), c1) -> (srliw (sraiw Y, 31), c3 - 32)
+          // if c1 is a mask with c3 leading zeros and c2 >= 32 and c3-c2==1.
+          //
+          // This pattern occurs when (i32 (srl (sra 31), c3 - 32)) is type
+          // legalized and goes through DAG combine.
+          if (C2 >= 32 && (Leading - C2) == 1 && N0.hasOneUse() &&
+              X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+              cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32) {
+            SDNode *SRAIW =
+                CurDAG->getMachineNode(RISCV::SRAIW, DL, VT, X.getOperand(0),
+                                       CurDAG->getTargetConstant(31, DL, VT));
+            SDNode *SRLIW = CurDAG->getMachineNode(
+                RISCV::SRLIW, DL, VT, SDValue(SRAIW, 0),
+                CurDAG->getTargetConstant(Leading - 32, DL, VT));
+            ReplaceNode(Node, SRLIW);
+            return;
+          }
+
+          // (srli (slli x, c3-c2), c3).
+          // Skip if we could use (zext.w (sraiw X, C2)).
+          bool Skip = Subtarget->hasStdExtZba() && Leading == 32 &&
+                      X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+                      cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32;
+          // Also Skip if we can use bexti.
+          Skip |= Subtarget->hasStdExtZbs() && Leading == XLen - 1;
+          if (OneUseOrZExtW && !Skip) {
+            SDNode *SLLI = CurDAG->getMachineNode(
+                RISCV::SLLI, DL, VT, X,
+                CurDAG->getTargetConstant(Leading - C2, DL, VT));
+            SDNode *SRLI = CurDAG->getMachineNode(
+                RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
+                CurDAG->getTargetConstant(Leading, DL, VT));
+            ReplaceNode(Node, SRLI);
+            return;
+          }
         }
+      }
 
-        // (srli (slli x, c3-c2), c3).
-        // Skip if we could use (zext.w (sraiw X, C2)).
-        bool Skip = Subtarget->hasStdExtZba() && Leading == 32 &&
-                    X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
-                    cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32;
-        // Also Skip if we can use bexti.
-        Skip |= Subtarget->hasStdExtZbs() && Leading == XLen - 1;
-        if (OneUseOrZExtW && !Skip) {
-          SDNode *SLLI = CurDAG->getMachineNode(
-              RISCV::SLLI, DL, VT, X,
-              CurDAG->getTargetConstant(Leading - C2, DL, VT));
-          SDNode *SRLI = CurDAG->getMachineNode(
-              RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
-              CurDAG->getTargetConstant(Leading, DL, VT));
-          ReplaceNode(Node, SRLI);
-          return;
+      // Turn (and (shl x, c2), c1) -> (srli (slli x, c2+c3), c3) if c1 is a
+      // mask shifted by c2 bits with c3 leading zeros.
+      if (LeftShift && isShiftedMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+
+        if (C2 + Leading < XLen &&
+            C1 == (maskTrailingOnes<uint64_t>(XLen - (C2 + Leading)) << C2)) {
+          // Use slli.uw when possible.
+          if ((XLen - (C2 + Leading)) == 32 && Subtarget->hasStdExtZba()) {
+            SDNode *SLLI_UW =
+                CurDAG->getMachineNode(RISCV::SLLI_UW, DL, VT, X,
+                                       CurDAG->getTargetConstant(C2, DL, VT));
+            ReplaceNode(Node, SLLI_UW);
+            return;
+          }
+
+          // (srli (slli x, c2+c3), c3)
+          if (OneUseOrZExtW && !IsCANDI) {
+            SDNode *SLLI = CurDAG->getMachineNode(
+                RISCV::SLLI, DL, VT, X,
+                CurDAG->getTargetConstant(C2 + Leading, DL, VT));
+            SDNode *SRLI = CurDAG->getMachineNode(
+                RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
+                CurDAG->getTargetConstant(Leading, DL, VT));
+            ReplaceNode(Node, SRLI);
+            return;
+          }
         }
       }
-    }
 
-    // Turn (and (shl x, c2), c1) -> (srli (slli c2+c3), c3) if c1 is a mask
-    // shifted by c2 bits with c3 leading zeros.
-    if (LeftShift && isShiftedMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-
-      if (C2 + Leading < XLen &&
-          C1 == (maskTrailingOnes<uint64_t>(XLen - (C2 + Leading)) << C2)) {
-        // Use slli.uw when possible.
-        if ((XLen - (C2 + Leading)) == 32 && Subtarget->hasStdExtZba()) {
-          SDNode *SLLI_UW = CurDAG->getMachineNode(
-              RISCV::SLLI_UW, DL, VT, X, CurDAG->getTargetConstant(C2, DL, VT));
-          ReplaceNode(Node, SLLI_UW);
+      // Turn (and (srl x, c2), c1) -> (slli (srli x, c2+c3), c3) if c1 is a
+      // shifted mask with c2 leading zeros and c3 trailing zeros.
+      if (!LeftShift && isShiftedMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+        unsigned Trailing = countTrailingZeros(C1);
+        if (Leading == C2 && C2 + Trailing < XLen && OneUseOrZExtW &&
+            !IsCANDI) {
+          unsigned SrliOpc = RISCV::SRLI;
+          // If the input is zexti32 we should use SRLIW.
+          if (X.getOpcode() == ISD::AND &&
+              isa<ConstantSDNode>(X.getOperand(1)) &&
+              X.getConstantOperandVal(1) == UINT64_C(0xFFFFFFFF)) {
+            SrliOpc = RISCV::SRLIW;
+            X = X.getOperand(0);
+          }
+          SDNode *SRLI = CurDAG->getMachineNode(
+              SrliOpc, DL, VT, X,
+              CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
           return;
         }
-
-        // (srli (slli c2+c3), c3)
-        if (OneUseOrZExtW && !IsCANDI) {
+        // If the leading zero count is C2+32, we can use SRLIW instead of SRLI.
+        if (Leading > 32 && (Leading - 32) == C2 && C2 + Trailing < 32 &&
+            OneUseOrZExtW && !IsCANDI) {
+          SDNode *SRLIW = CurDAG->getMachineNode(
+              RISCV::SRLIW, DL, VT, X,
+              CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
           SDNode *SLLI = CurDAG->getMachineNode(
-              RISCV::SLLI, DL, VT, X,
-              CurDAG->getTargetConstant(C2 + Leading, DL, VT));
-          SDNode *SRLI = CurDAG->getMachineNode(
-              RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
-              CurDAG->getTargetConstant(Leading, DL, VT));
-          ReplaceNode(Node, SRLI);
+              RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
           return;
         }
       }
-    }
 
-    // Turn (and (shr x, c2), c1) -> (slli (srli x, c2+c3), c3) if c1 is a
-    // shifted mask with c2 leading zeros and c3 trailing zeros.
-    if (!LeftShift && isShiftedMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-      unsigned Trailing = countTrailingZeros(C1);
-      if (Leading == C2 && C2 + Trailing < XLen && OneUseOrZExtW && !IsCANDI) {
-        unsigned SrliOpc = RISCV::SRLI;
-        // If the input is zexti32 we should use SRLIW.
-        if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
-            X.getConstantOperandVal(1) == UINT64_C(0xFFFFFFFF)) {
-          SrliOpc = RISCV::SRLIW;
-          X = X.getOperand(0);
+      // Turn (and (shl x, c2), c1) -> (slli (srli x, c3-c2), c3) if c1 is a
+      // shifted mask with no leading zeros and c3 trailing zeros.
+      if (LeftShift && isShiftedMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+        unsigned Trailing = countTrailingZeros(C1);
+        if (Leading == 0 && C2 < Trailing && OneUseOrZExtW && !IsCANDI) {
+          SDNode *SRLI = CurDAG->getMachineNode(
+              RISCV::SRLI, DL, VT, X,
+              CurDAG->getTargetConstant(Trailing - C2, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
+        }
+        // If we have (32-C2) leading zeros, we can use SRLIW instead of SRLI.
+        if (C2 < Trailing && Leading + C2 == 32 && OneUseOrZExtW && !IsCANDI) {
+          SDNode *SRLIW = CurDAG->getMachineNode(
+              RISCV::SRLIW, DL, VT, X,
+              CurDAG->getTargetConstant(Trailing - C2, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
         }
-        SDNode *SRLI = CurDAG->getMachineNode(
-            SrliOpc, DL, VT, X,
-            CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-      // If the leading zero count is C2+32, we can use SRLIW instead of SRLI.
-      if (Leading > 32 && (Leading - 32) == C2 && C2 + Trailing < 32 &&
-          OneUseOrZExtW && !IsCANDI) {
-        SDNode *SRLIW = CurDAG->getMachineNode(
-            RISCV::SRLIW, DL, VT, X,
-            CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-    }
-
-    // Turn (and (shl x, c2), c1) -> (slli (srli x, c3-c2), c3) if c1 is a
-    // shifted mask with no leading zeros and c3 trailing zeros.
-    if (LeftShift && isShiftedMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-      unsigned Trailing = countTrailingZeros(C1);
-      if (Leading == 0 && C2 < Trailing && OneUseOrZExtW && !IsCANDI) {
-        SDNode *SRLI = CurDAG->getMachineNode(
-            RISCV::SRLI, DL, VT, X,
-            CurDAG->getTargetConstant(Trailing - C2, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-      // If we have (32-C2) leading zeros, we can use SRLIW instead of SRLI.
-      if (C2 < Trailing && Leading + C2 == 32 && OneUseOrZExtW && !IsCANDI) {
-        SDNode *SRLIW = CurDAG->getMachineNode(
-            RISCV::SRLIW, DL, VT, X,
-            CurDAG->getTargetConstant(Trailing - C2, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
       }
     }
 

diff  --git a/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll b/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
index bd99b1c44a57e..7434f620bef0e 100644
--- a/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
+++ b/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
@@ -189,9 +189,8 @@ define signext i32 @test11(i32 signext %x) nounwind {
 ;
 ; RV64-LABEL: test11:
 ; RV64:       # %bb.0:
+; RV64-NEXT:    andi a0, a0, -241
 ; RV64-NEXT:    slliw a0, a0, 17
-; RV64-NEXT:    lui a1, 1040864
-; RV64-NEXT:    and a0, a0, a1
 ; RV64-NEXT:    ret
   %or = shl i32 %x, 17
   %shl = and i32 %or, -31588352

diff  --git a/llvm/test/CodeGen/RISCV/shift-and.ll b/llvm/test/CodeGen/RISCV/shift-and.ll
index 288c2cd28c901..525ef624179c6 100644
--- a/llvm/test/CodeGen/RISCV/shift-and.ll
+++ b/llvm/test/CodeGen/RISCV/shift-and.ll
@@ -92,9 +92,8 @@ define i32 @test5(i32 %x) {
 ;
 ; RV64I-LABEL: test5:
 ; RV64I:       # %bb.0:
+; RV64I-NEXT:    andi a0, a0, -1024
 ; RV64I-NEXT:    slliw a0, a0, 6
-; RV64I-NEXT:    lui a1, 1048560
-; RV64I-NEXT:    and a0, a0, a1
 ; RV64I-NEXT:    ret
   %a = shl i32 %x, 6
   %b = and i32 %a, -65536


        


More information about the llvm-commits mailing list