[llvm] 36d8316 - [RISCV] Reduce duplicate code for calling SimplifyDemandedBits.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 19 07:18:03 PDT 2021


Author: Craig Topper
Date: 2021-08-19T07:09:38-07:00
New Revision: 36d8316cc8b6ab85d0d3ed46a04490afa5e49a29

URL: https://github.com/llvm/llvm-project/commit/36d8316cc8b6ab85d0d3ed46a04490afa5e49a29
DIFF: https://github.com/llvm/llvm-project/commit/36d8316cc8b6ab85d0d3ed46a04490afa5e49a29.diff

LOG: [RISCV] Reduce duplicate code for calling SimplifyDemandedBits.

This encapsulates the APInt creation and worklist management into
a helper function.

To keep one common interface I've used Log2_32 in places that
previously created a mask by subtracting 1 from a power of 2.

Differential Revision: https://reviews.llvm.org/D108324

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ec82e6ea927a..6659d26a8f1d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5968,6 +5968,20 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
 
+  // Helper to call SimplifyDemandedBits on an operand of N where only some low
+  // bits are demanded. N will be added to the Worklist if it was not deleted.
+  // Caller should return SDValue(N, 0) if this returns true.
+  auto SimplifyDemandedLowBitsHelper = [&](unsigned OpNo, unsigned LowBits) {
+    SDValue Op = N->getOperand(OpNo);
+    APInt Mask = APInt::getLowBitsSet(Op.getValueSizeInBits(), LowBits);
+    if (!SimplifyDemandedBits(Op, Mask, DCI))
+      return false;
+
+    if (N->getOpcode() != ISD::DELETED_NODE)
+      DCI.AddToWorklist(N);
+    return true;
+  };
+
   switch (N->getOpcode()) {
   default:
     break;
@@ -6019,136 +6033,85 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::ROLW:
   case RISCVISD::RORW: {
     // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
-    if (SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI) ||
-        SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 5))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::CLZW:
   case RISCVISD::CTZW: {
     // Only the lower 32 bits of the first operand are read
-    SDValue Op0 = N->getOperand(0);
-    APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
-    if (SimplifyDemandedBits(Op0, Mask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::FSL:
   case RISCVISD::FSR: {
     // Only the lower log2(Bitwidth)+1 bits of the the shift amount are read.
-    SDValue ShAmt = N->getOperand(2);
-    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    unsigned BitWidth = N->getOperand(2).getValueSizeInBits();
     assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
-    APInt ShAmtMask(BitWidth, (BitWidth * 2) - 1);
-    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(2, Log2_32(BitWidth) + 1))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::FSLW:
   case RISCVISD::FSRW: {
     // Only the lower 32 bits of Values and lower 6 bits of shift amount are
     // read.
-    SDValue Op0 = N->getOperand(0);
-    SDValue Op1 = N->getOperand(1);
-    SDValue ShAmt = N->getOperand(2);
-    APInt OpMask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
-    APInt ShAmtMask = APInt::getLowBitsSet(ShAmt.getValueSizeInBits(), 6);
-    if (SimplifyDemandedBits(Op0, OpMask, DCI) ||
-        SimplifyDemandedBits(Op1, OpMask, DCI) ||
-        SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 32) ||
+        SimplifyDemandedLowBitsHelper(2, 6))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::GREV:
   case RISCVISD::GORC: {
     // Only the lower log2(Bitwidth) bits of the the shift amount are read.
-    SDValue ShAmt = N->getOperand(1);
-    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    unsigned BitWidth = N->getOperand(1).getValueSizeInBits();
     assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
-    APInt ShAmtMask(BitWidth, BitWidth - 1);
-    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth)))
       return SDValue(N, 0);
-    }
 
     return combineGREVI_GORCI(N, DCI.DAG);
   }
   case RISCVISD::GREVW:
   case RISCVISD::GORCW: {
     // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
-    if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
-        SimplifyDemandedBits(RHS, RHSMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 5))
       return SDValue(N, 0);
-    }
 
     return combineGREVI_GORCI(N, DCI.DAG);
   }
   case RISCVISD::SHFL:
   case RISCVISD::UNSHFL: {
-    // Only the lower log2(Bitwidth) bits of the the shift amount are read.
-    SDValue ShAmt = N->getOperand(1);
-    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    // Only the lower log2(Bitwidth)-1 bits of the the shift amount are read.
+    unsigned BitWidth = N->getOperand(1).getValueSizeInBits();
     assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
-    APInt ShAmtMask(BitWidth, (BitWidth / 2) - 1);
-    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth) - 1))
       return SDValue(N, 0);
-    }
 
     break;
   }
   case RISCVISD::SHFLW:
   case RISCVISD::UNSHFLW: {
-    // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
+    // Only the lower 32 bits of LHS and lower 4 bits of RHS are read.
     SDValue LHS = N->getOperand(0);
     SDValue RHS = N->getOperand(1);
     APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
     APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 4);
-    if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
-        SimplifyDemandedBits(RHS, RHSMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 4))
       return SDValue(N, 0);
-    }
 
     break;
   }
   case RISCVISD::BCOMPRESSW:
   case RISCVISD::BDECOMPRESSW: {
     // Only the lower 32 bits of LHS and RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt Mask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    if (SimplifyDemandedBits(LHS, Mask, DCI) ||
-        SimplifyDemandedBits(RHS, Mask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 32))
       return SDValue(N, 0);
-    }
 
     break;
   }


        


More information about the llvm-commits mailing list