[libcxx-commits] [clang] [flang] [libc] [libcxx] [clang-tools-extra] [lldb] [lld] [libunwind] [llvm] [compiler-rt] [X86] Use RORX over SHR imm (PR #77964)

Bryce Wilson via libcxx-commits libcxx-commits at lists.llvm.org
Thu Jan 25 07:31:35 PST 2024


================
@@ -4216,6 +4217,97 @@ MachineSDNode *X86DAGToDAGISel::emitPCMPESTR(unsigned ROpc, unsigned MOpc,
   return CNode;
 }
 
+// When the consumer of a right shift (arithmetic or logical) wouldn't notice
+// the difference if the instruction were a rotate right instead (because the
+// bits shifted in are truncated away), the shift can be replaced by the RORX
+// instruction from BMI2. RORX doesn't set flags and can write its result to a
+// different register than its source. However, it increases code size in most
+// cases and doesn't leave the high bits in a useful state. There may be other
+// situations where this transformation is profitable given those conditions,
+// but currently it is only made when it likely avoids spilling flags.
+bool X86DAGToDAGISel::rightShiftUncloberFlags(SDNode *N) {
+  EVT VT = N->getValueType(0);
+
+  // Target has to have BMI2 for RORX
+  if (!Subtarget->hasBMI2())
+    return false;
+
+  // Only handle scalar shifts.
+  if (VT.isVector())
+    return false;
+
+  unsigned OpSize;
+  if (VT == MVT::i64)
+    OpSize = 64;
+  else if (VT == MVT::i32)
+    OpSize = 32;
+  else if (VT == MVT::i16)
+    OpSize = 16;
+  else if (VT == MVT::i8)
+    return false; // i8 shift can't be truncated.
+  else
+    llvm_unreachable("Unexpected shift size");
+
+  unsigned TruncateSize = 0;
+  // This only works when the result is truncated.
+  for (const SDNode *User : N->uses()) {
+    if (!User->isMachineOpcode() ||
+        User->getMachineOpcode() != TargetOpcode::EXTRACT_SUBREG)
+      return false;
+    EVT TruncateType = User->getValueType(0);
+    if (TruncateType == MVT::i32)
+      TruncateSize = std::max(TruncateSize, 32U);
+    else if (TruncateType == MVT::i16)
+      TruncateSize = std::max(TruncateSize, 16U);
+    else if (TruncateType == MVT::i8)
+      TruncateSize = std::max(TruncateSize, 8U);
+    else
+      return false;
+  }
+  if (TruncateSize >= OpSize)
+    return false;
+
+  // The shift must be by an immediate small enough that none of the zero- or
+  // sign-extended bits reach the truncated result.
+  auto *ShiftAmount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!ShiftAmount || ShiftAmount->getZExtValue() > OpSize - TruncateSize)
+    return false;
+
+  // Only make the replacement when it avoids clobbering used flags. This is
+  // similar to the heuristic used in the conversion to LEA, namely checking
+  // whether the shifted operand comes from an instruction whose flags are
+  // used. This will have both false positives and false negatives. Ideally,
+  // this decision would be made later on, perhaps in copy-to-flags lowering
+  // or in register allocation.
+  bool MightClobberFlags = false;
+  SDNode *Input = N->getOperand(0).getNode();
+  for (auto Use : Input->uses()) {
+    if (Use->getOpcode() == ISD::CopyToReg) {
+      auto *RegisterNode =
+          dyn_cast<RegisterSDNode>(Use->getOperand(1).getNode());
+      if (RegisterNode && RegisterNode->getReg() == X86::EFLAGS) {
+        MightClobberFlags = true;
+        break;
+      }
+    }
+  }
+  if (!MightClobberFlags)
+    return false;
----------------
Bryce-MW wrote:

It should be correct? I've clarified the names and explanation a bit, but it's possible that I got the logic wrong.
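The bound on `ShiftAmount` can be sanity-checked outside LLVM. What follows is a minimal standalone sketch and is not code from the patch. It models values as plain integers masked to `OpSize` bits. The helper names (`rotr`, `srl`, `sra`, `rotateIsSafe`, `patchAllows`, `boundMatchesPatch`) are hypothetical. For the i16 and i32 operand widths the function accepts, and truncation widths below them, it checks that a rotate right and both kinds of right shift agree on the low `TruncateSize` bits exactly when `ShiftAmount <= OpSize - TruncateSize`. Every 16-bit value is tested, but only a few 32-bit patterns. 64-bit operands are left out to keep the masking free of out-of-range shifts.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Mask with the low Bits bits set; valid for 0 <= Bits < 64.
static uint64_t lowMask(unsigned Bits) { return (1ULL << Bits) - 1; }

// Rotate the low N bits of X right by K, modelling RORX on an N-bit register.
static uint64_t rotr(uint64_t X, unsigned K, unsigned N) {
  assert(K > 0 && K < N && N <= 32 && "sketch only models widths up to 32");
  X &= lowMask(N);
  return ((X >> K) | (X << (N - K))) & lowMask(N);
}

// Logical right shift (SHR) of an N-bit value: zeros fill the top K bits.
static uint64_t srl(uint64_t X, unsigned K, unsigned N) {
  return (X & lowMask(N)) >> K;
}

// Arithmetic right shift (SAR) of an N-bit value: the top K bits copy the sign.
static uint64_t sra(uint64_t X, unsigned K, unsigned N) {
  X &= lowMask(N);
  uint64_t Result = X >> K;
  if (X >> (N - 1))                         // sign bit set
    Result |= lowMask(N) & ~lowMask(N - K); // fill bits N-K .. N-1
  return Result;
}

// True if, for every input, both shifts agree with the rotate on the low T bits.
static bool rotateIsSafe(unsigned N, unsigned T, unsigned K,
                         const std::vector<uint64_t> &Inputs) {
  for (uint64_t X : Inputs) {
    uint64_t Rot = rotr(X, K, N) & lowMask(T);
    if ((srl(X, K, N) & lowMask(T)) != Rot ||
        (sra(X, K, N) & lowMask(T)) != Rot)
      return false;
  }
  return true;
}

// Mirrors the patch: it gives up when ShiftAmount > OpSize - TruncateSize.
static bool patchAllows(unsigned N, unsigned T, unsigned K) {
  return K <= N - T;
}

// Compares the patch's bound against the brute-force answer for each case.
static bool boundMatchesPatch() {
  std::vector<uint64_t> Inputs;
  for (uint64_t X = 0; X <= 0xFFFF; ++X)
    Inputs.push_back(X); // every 16-bit value
  for (uint64_t X : {0xFFFFFFFFULL, 0x80000000ULL, 0x12345678ULL, 0xDEADBEEFULL})
    Inputs.push_back(X); // a few 32-bit patterns, including all-ones
  const unsigned Cases[][2] = {{16, 8}, {32, 8}, {32, 16}}; // {OpSize, TruncateSize}
  for (const auto &C : Cases)
    for (unsigned K = 1; K < C[0]; ++K)
      if (rotateIsSafe(C[0], C[1], K, Inputs) != patchAllows(C[0], C[1], K))
        return false;
  return true;
}
```

Calling `boundMatchesPatch()` from a small `main` should return `true`, meaning the cut-off in the patch matches brute force for these cases. The underlying reason is that rotating right by K moves original bits K through K + TruncateSize - 1 into the low bits. The bits that wrap around land at positions OpSize - K and above. Those positions stay outside the truncated result as long as K <= OpSize - TruncateSize.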

https://github.com/llvm/llvm-project/pull/77964

