[llvm] [RISCV] Add DAG combine to turn (sub (shl X, 8-Y), (shr X, Y)) into orc.b (PR #111828)

Daniel Mokeev via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 10 09:12:17 PDT 2024


https://github.com/damokeev updated https://github.com/llvm/llvm-project/pull/111828

>From 78eb2d8d28ed2bf679ef4f25e40237d4d5ae6483 Mon Sep 17 00:00:00 2001
From: Daniel Mokeev <mokeev.gh at gmail.com>
Date: Wed, 9 Oct 2024 18:14:50 +0200
Subject: [PATCH 1/3] First version of orc.b optimizations

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  72 +++-
 llvm/test/CodeGen/RISCV/orc-b-patterns.ll   | 345 ++++++++++++++++++++
 2 files changed, 414 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/orc-b-patterns.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 01fa418e4dbdf4..1619b684dd5818 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -23,12 +23,14 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13572,9 +13574,71 @@ static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
 }
 
+// Looks for (sub (shl X, 8-N), (shr X, N)) where the N-th bit in each byte is potentially set. Replace with orc.b. 
+static SDValue combineSubShiftToOrcBGeneralized(SDNode *N, SelectionDAG &DAG,
+                                     const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZbb())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  if (VT != Subtarget.getXLenVT() && VT != MVT::i32 && VT != MVT::i16)
+    return SDValue(); 
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  
+  if (N0->getOpcode() != ISD::SHL){
+    return SDValue();
+  }  
+  
+  auto *ShAmtCLeft = dyn_cast<ConstantSDNode>(N0.getOperand(1));
+  if (!ShAmtCLeft)
+    return SDValue();
+  unsigned ShiftedAmount = 8 - ShAmtCLeft->getZExtValue();
+  SDValue LeftShiftOperand = N0->getOperand(0);
+  SDValue RightShiftOperand;
+
+  if (N1->getOpcode() == ISD::SRL){ // (sub (shl X, 8 - N), (srl X, N)) case
+    auto *ShAmtCRight = dyn_cast<ConstantSDNode>(N1.getOperand(1));
+    // Note that the (sub (X, (shr X, 8))) is a degenerate case that should not get optimized,
+    // as we would be replacing a subtraction with an orc.b
+    // (!N0.hasOneUse() && !N1.hasOneUse())
+    if (!ShAmtCRight || ShAmtCRight->getZExtValue() == 8 || ShAmtCRight->getZExtValue() != ShiftedAmount )
+    {
+        return SDValue();
+    }
+    if (!N0.hasOneUse() && !N1.hasOneUse()){
+      dbgs() << "Both operands both have > 1 use\n";
+      return SDValue();
+    }
+    RightShiftOperand = N1.getOperand(0);
+  }
+  else{ // (sub (shl X, 8), X) case
+    if (!N0.hasOneUse()){
+      dbgs() << "N0 has > 1 uses\n";
+      return SDValue();
+    }
+    // llvm::errs() << "N0 has " << N0->get;
+    RightShiftOperand = N1;
+  }
+
+  APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0x1));
+  Mask <<= ShiftedAmount;
+  // Check that X has indeed the right shape (only the N-th bit can be set in every byte)
+  if(!DAG.MaskedValueIsZero(LeftShiftOperand, ~Mask))
+    return SDValue();
+
+  if (LeftShiftOperand != RightShiftOperand)
+    return SDValue();
+  dbgs() << "Optimized for node - " << N->getDebugLoc() << "\n";
+  return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, LeftShiftOperand);
+}
+
+
 // Looks for (sub (shl X, 8), X) where only bits 8, 16, 24, 32, etc. of X are
 // non-zero. Replace with orc.b.
-static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
+__attribute__((unused)) static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
                                      const RISCVSubtarget &Subtarget) {
   if (!Subtarget.hasStdExtZbb())
     return SDValue();
@@ -13596,7 +13660,7 @@ static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
 
   APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0xfe));
   if (!DAG.MaskedValueIsZero(N1, Mask))
-    return SDValue();
+    return SDValue(); 
 
   return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, N1);
 }
@@ -13623,7 +13687,9 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
 
   if (SDValue V = combineBinOpOfZExt(N, DAG))
     return V;
-  if (SDValue V = combineSubShiftToOrcB(N, DAG, Subtarget))
+  // if (SDValue V = combineSubShiftToOrcB(N, DAG, Subtarget))
+  //   return V;
+  if (SDValue V  = combineSubShiftToOrcBGeneralized(N, DAG, Subtarget))
     return V;
 
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
diff --git a/llvm/test/CodeGen/RISCV/orc-b-patterns.ll b/llvm/test/CodeGen/RISCV/orc-b-patterns.ll
new file mode 100644
index 00000000000000..b7c67cf8930133
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/orc-b-patterns.ll
@@ -0,0 +1,345 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck %s -check-prefixes=CHECK,RV32I
+; RUN: llc -mtriple=riscv32 -mattr=+zbb -verify-machineinstrs < %s \
+; RUN:   | FileCheck %s -check-prefixes=CHECK,RV32ZBB
+
+define i32 @orc_b_i32_mul255(i32 %x) nounwind {
+; RV32I-LABEL: orc_b_i32_mul255:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a1, 4112
+; RV32I-NEXT:    addi a1, a1, 257
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 8
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_mul255:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a1, 4112
+; RV32ZBB-NEXT:    addi a1, a1, 257
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 16843009
+  %mul = mul nuw nsw i32 %and, 255
+  ret i32 %mul
+}
+
+
+define i32 @orc_b_i32_sub_shl8x_x_lsb(i32  %x)  {
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_lsb:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a1, 4112
+; RV32I-NEXT:    addi a1, a1, 257
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 8
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_lsb:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a1, 4112
+; RV32ZBB-NEXT:    addi a1, a1, 257
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 16843009
+  %sub = mul nuw i32 %and, 255
+  ret i32 %sub
+}
+
+define  i32 @orc_b_i32_sub_shl8x_x_b1(i32  %x)  {
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_b1:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a1, 8224
+; RV32I-NEXT:    addi a1, a1, 514
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 7
+; RV32I-NEXT:    srli a0, a0, 1
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_b1:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a1, 8224
+; RV32ZBB-NEXT:    addi a1, a1, 514
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 33686018
+  %shl = shl i32 %and, 7
+  %shr = lshr exact i32 %and, 1
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define  i32 @orc_b_i32_sub_shl8x_x_b2(i32  %x)  {
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_b2:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a1, 16448
+; RV32I-NEXT:    addi a1, a1, 1028
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 6
+; RV32I-NEXT:    srli a0, a0, 2
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_b2:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a1, 16448
+; RV32ZBB-NEXT:    addi a1, a1, 1028
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 67372036
+  %shl = shl i32 %and, 6
+  %shr = lshr exact i32 %and, 2
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define i32 @orc_b_i32_sub_shl8x_x_b3(i32  %x)  {
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_b3:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 24672
+; CHECK-NEXT:    addi a1, a1, 1542
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 5
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 101058054
+  %shl = shl nuw i32 %and, 5
+  %shr = lshr i32 %and, 3
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define  i32 @orc_b_i32_sub_shl8x_x_b4(i32  %x)  {
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_b4:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 32897
+; CHECK-NEXT:    addi a1, a1, -2040
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 4
+; CHECK-NEXT:    srli a0, a0, 4
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 134744072
+  %shl = shl nuw i32 %and, 4
+  %shr = lshr i32 %and, 4
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define  i32 @orc_b_i32_sub_shl8x_x_b5(i32  %x)  {
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_b5:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 65793
+; CHECK-NEXT:    addi a1, a1, 16
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 3
+; CHECK-NEXT:    srli a0, a0, 5
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 269488144
+  %shl = shl nuw i32 %and, 3
+  %shr = lshr i32 %and, 5
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define i32 @orc_b_i32_sub_shl8x_x_b6(i32 %x)  {
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_b6:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 131586
+; CHECK-NEXT:    addi a1, a1, 32
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 2
+; CHECK-NEXT:    srli a0, a0, 6
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 538976288
+  %shl = shl nuw i32 %and, 2
+  %shr = lshr i32 %and, 6
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define i32 @orc_b_i32_sub_shl8x_x_b7(i32 %x)  {
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_b7:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 263172
+; CHECK-NEXT:    addi a1, a1, 64
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 1
+; CHECK-NEXT:    srli a0, a0, 7
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 1077952576
+  %shl = shl nuw i32 %and, 1
+  %shr = lshr i32 %and, 7
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+define i32 @orc_b_i32_sub_shl8x_x_b1_shl_used(i32 %x, ptr %arr) {
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_b1_shl_used:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a2, 8224
+; RV32I-NEXT:    addi a2, a2, 514
+; RV32I-NEXT:    and a0, a0, a2
+; RV32I-NEXT:    slli a2, a0, 7
+; RV32I-NEXT:    srli a3, a0, 1
+; RV32I-NEXT:    sub a0, a2, a3
+; RV32I-NEXT:    sw a3, 0(a1)
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_b1_shl_used:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a2, 8224
+; RV32ZBB-NEXT:    addi a2, a2, 514
+; RV32ZBB-NEXT:    and a0, a0, a2
+; RV32ZBB-NEXT:    srli a2, a0, 1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    sw a2, 0(a1)
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 33686018
+  %shl = shl i32 %and, 7
+  %shr = lshr exact i32 %and, 1
+  store i32 %shr, ptr %arr, align 4
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+define i32 @orc_b_i32_sub_shl8x_x_b1_srl_used(i32  %x, ptr %arr) {
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_b1_srl_used:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a2, 8224
+; RV32I-NEXT:    addi a2, a2, 514
+; RV32I-NEXT:    and a0, a0, a2
+; RV32I-NEXT:    slli a2, a0, 7
+; RV32I-NEXT:    srli a0, a0, 1
+; RV32I-NEXT:    sub a0, a2, a0
+; RV32I-NEXT:    sw a2, 0(a1)
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_b1_srl_used:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a2, 8224
+; RV32ZBB-NEXT:    addi a2, a2, 514
+; RV32ZBB-NEXT:    and a0, a0, a2
+; RV32ZBB-NEXT:    slli a2, a0, 7
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    sw a2, 0(a1)
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 33686018
+  %shl = shl i32 %and, 7
+  %shr = lshr exact i32 %and, 1
+  store i32 %shl, ptr %arr, align 4
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define i32 @orc_b_i32_sub_shl8x_x_b1_not_used(i32  %x, ptr %arr) {
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_b1_not_used:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    lui a1, 8224
+; RV32I-NEXT:    addi a1, a1, 514
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 7
+; RV32I-NEXT:    srli a0, a0, 1
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_b1_not_used:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    lui a1, 8224
+; RV32ZBB-NEXT:    addi a1, a1, 514
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+entry:
+  %and = and i32 %x, 33686018
+  %shl = shl i32 %and, 7
+  %shr = lshr exact i32 %and, 1
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+define i32 @orc_b_i32_sub_shl8x_x_shl_used(i32  %x, ptr %arr){
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_shl_used:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a2, 4112
+; CHECK-NEXT:    addi a2, a2, 257
+; CHECK-NEXT:    and a0, a0, a2
+; CHECK-NEXT:    slli a2, a0, 8
+; CHECK-NEXT:    sub a0, a2, a0
+; CHECK-NEXT:    sw a2, 0(a1)
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 16843009
+  %shl = shl i32 %and, 8
+  store i32 %shl, ptr %arr, align 4
+  %sub = mul nuw i32 %and, 255
+  ret i32 %sub
+}
+
+define i32 @orc_b_i32_sub_shl8x_x_b1_both_used(i32  %x, ptr %arr) {
+; CHECK-LABEL: orc_b_i32_sub_shl8x_x_b1_both_used:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a2, 8224
+; CHECK-NEXT:    addi a2, a2, 514
+; CHECK-NEXT:    and a0, a0, a2
+; CHECK-NEXT:    slli a2, a0, 7
+; CHECK-NEXT:    srli a3, a0, 1
+; CHECK-NEXT:    sw a2, 0(a1)
+; CHECK-NEXT:    sub a0, a2, a3
+; CHECK-NEXT:    sw a3, 4(a1)
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 33686018
+  %shl = shl i32 %and, 7
+  %shr = lshr exact i32 %and, 1
+  store i32 %shl, ptr %arr, align 4
+  %arrayidx1 = getelementptr inbounds i8, ptr %arr, i32 4
+  store i32 %shr, ptr %arrayidx1, align 4
+  %sub = sub nsw i32 %shl, %shr
+  ret i32 %sub
+}
+
+
+define i32 @orc_b_i32_sub_x_shr8x(i32 %x)  {
+; CHECK-LABEL: orc_b_i32_sub_x_shr8x:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 4112
+; CHECK-NEXT:    addi a1, a1, 257
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    srli a1, a0, 8
+; CHECK-NEXT:    sub a0, a0, a1
+; CHECK-NEXT:    ret
+entry:
+  %and = and i32 %x, 16843009
+  %shr = lshr i32 %and, 8
+  %sub = sub nsw i32 %and, %shr
+  ret i32 %sub
+}

>From 3141c5dee88dc74b171940a713b2528220fcca6b Mon Sep 17 00:00:00 2001
From: Daniel Mokeev <mokeev.gh at gmail.com>
Date: Thu, 10 Oct 2024 13:43:53 +0200
Subject: [PATCH 2/3] [RISCV] Add DAG combine to turn (sub (shl X, 8-Y), (shr
 X, Y)) into orc.b

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 85 ++++++---------------
 llvm/test/CodeGen/RISCV/orc-b-patterns.ll   | 27 +++++++
 2 files changed, 50 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1619b684dd5818..ee948ca71a31c9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13574,8 +13574,11 @@ static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
 }
 
-// Looks for (sub (shl X, 8-N), (shr X, N)) where the N-th bit in each byte is potentially set. Replace with orc.b. 
-static SDValue combineSubShiftToOrcBGeneralized(SDNode *N, SelectionDAG &DAG,
+// Looks for (sub (shl X, 8-Y), (shr X, Y)) where the Y-th bit in each byte is
+// potentially set. It is fine for Y to be 0, meaning that (sub (shl X, 8), X)
+// is also valid. Replace with (orc.b X). For example, 0b0000_1000_0000_1000 is
+// valid with Y=3, while 0b0000_1000_0000_0100 is not.
+static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
                                      const RISCVSubtarget &Subtarget) {
   if (!Subtarget.hasStdExtZbb())
     return SDValue();
@@ -13583,86 +13586,46 @@ static SDValue combineSubShiftToOrcBGeneralized(SDNode *N, SelectionDAG &DAG,
   EVT VT = N->getValueType(0);
 
   if (VT != Subtarget.getXLenVT() && VT != MVT::i32 && VT != MVT::i16)
-    return SDValue(); 
+    return SDValue();
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  
-  if (N0->getOpcode() != ISD::SHL){
+
+  if (N0->getOpcode() != ISD::SHL)
     return SDValue();
-  }  
   
   auto *ShAmtCLeft = dyn_cast<ConstantSDNode>(N0.getOperand(1));
   if (!ShAmtCLeft)
     return SDValue();
   unsigned ShiftedAmount = 8 - ShAmtCLeft->getZExtValue();
   SDValue LeftShiftOperand = N0->getOperand(0);
-  SDValue RightShiftOperand;
+  SDValue RightShiftOperand = N1;
+
+  if (ShiftedAmount != 0 && N1->getOpcode() != ISD::SRL)
+    return SDValue();
 
-  if (N1->getOpcode() == ISD::SRL){ // (sub (shl X, 8 - N), (srl X, N)) case
+  if (ShiftedAmount != 0) { // Right operand must be a right shift.
     auto *ShAmtCRight = dyn_cast<ConstantSDNode>(N1.getOperand(1));
-    // Note that the (sub (X, (shr X, 8))) is a degenerate case that should not get optimized,
-    // as we would be replacing a subtraction with an orc.b
-    // (!N0.hasOneUse() && !N1.hasOneUse())
-    if (!ShAmtCRight || ShAmtCRight->getZExtValue() == 8 || ShAmtCRight->getZExtValue() != ShiftedAmount )
-    {
-        return SDValue();
-    }
-    if (!N0.hasOneUse() && !N1.hasOneUse()){
-      dbgs() << "Both operands both have > 1 use\n";
+    if (!ShAmtCRight || ShAmtCRight->getZExtValue() != ShiftedAmount)
       return SDValue();
-    }
     RightShiftOperand = N1.getOperand(0);
   }
-  else{ // (sub (shl X, 8), X) case
-    if (!N0.hasOneUse()){
-      dbgs() << "N0 has > 1 uses\n";
-      return SDValue();
-    }
-    // llvm::errs() << "N0 has " << N0->get;
-    RightShiftOperand = N1;
-  }
 
-  APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0x1));
-  Mask <<= ShiftedAmount;
-  // Check that X has indeed the right shape (only the N-th bit can be set in every byte)
-  if(!DAG.MaskedValueIsZero(LeftShiftOperand, ~Mask))
+  // At least one shift should have a single use.
+  if (!N0.hasOneUse() && (ShiftedAmount == 0 || !N1.hasOneUse()))
     return SDValue();
 
   if (LeftShiftOperand != RightShiftOperand)
     return SDValue();
-  dbgs() << "Optimized for node - " << N->getDebugLoc() << "\n";
-  return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, LeftShiftOperand);
-}
-
-
-// Looks for (sub (shl X, 8), X) where only bits 8, 16, 24, 32, etc. of X are
-// non-zero. Replace with orc.b.
-__attribute__((unused)) static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
-                                     const RISCVSubtarget &Subtarget) {
-  if (!Subtarget.hasStdExtZbb())
-    return SDValue();
-
-  EVT VT = N->getValueType(0);
 
-  if (VT != Subtarget.getXLenVT() && VT != MVT::i32 && VT != MVT::i16)
-    return SDValue();
-
-  SDValue N0 = N->getOperand(0);
-  SDValue N1 = N->getOperand(1);
-
-  if (N0.getOpcode() != ISD::SHL || N0.getOperand(0) != N1 || !N0.hasOneUse())
-    return SDValue();
-
-  auto *ShAmtC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-  if (!ShAmtC || ShAmtC->getZExtValue() != 8)
+  APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0x1));
+  Mask <<= ShiftedAmount;
+  // Check that X has indeed the right shape (only the Y-th bit can be set in
+  // every byte).
+  if (!DAG.MaskedValueIsZero(LeftShiftOperand, ~Mask))
     return SDValue();
 
-  APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0xfe));
-  if (!DAG.MaskedValueIsZero(N1, Mask))
-    return SDValue(); 
-
-  return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, N1);
+  return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, LeftShiftOperand);
 }
 
 static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
@@ -13687,9 +13650,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
 
   if (SDValue V = combineBinOpOfZExt(N, DAG))
     return V;
-  // if (SDValue V = combineSubShiftToOrcB(N, DAG, Subtarget))
-  //   return V;
-  if (SDValue V  = combineSubShiftToOrcBGeneralized(N, DAG, Subtarget))
+  if (SDValue V = combineSubShiftToOrcB(N, DAG, Subtarget))
     return V;
 
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
diff --git a/llvm/test/CodeGen/RISCV/orc-b-patterns.ll b/llvm/test/CodeGen/RISCV/orc-b-patterns.ll
index b7c67cf8930133..184e66c14b33fc 100644
--- a/llvm/test/CodeGen/RISCV/orc-b-patterns.ll
+++ b/llvm/test/CodeGen/RISCV/orc-b-patterns.ll
@@ -51,6 +51,33 @@ entry:
   ret i32 %sub
 }
 
+define i32 @orc_b_i32_sub_shl8x_x_lsb_preshifted(i32 %x){
+; RV32I-LABEL: orc_b_i32_sub_shl8x_x_lsb_preshifted:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    srli a0, a0, 11
+; RV32I-NEXT:    lui a1, 16
+; RV32I-NEXT:    addi a1, a1, 257
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 8
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32_sub_shl8x_x_lsb_preshifted:
+; RV32ZBB:       # %bb.0: # %entry
+; RV32ZBB-NEXT:    srli a0, a0, 11
+; RV32ZBB-NEXT:    lui a1, 16
+; RV32ZBB-NEXT:    addi a1, a1, 257
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+entry:
+  %shr = lshr i32 %x, 11
+  %and = and i32 %shr, 16843009
+  %sub = mul nuw i32 %and, 255
+  ret i32 %sub
+}
+
+
 define  i32 @orc_b_i32_sub_shl8x_x_b1(i32  %x)  {
 ; RV32I-LABEL: orc_b_i32_sub_shl8x_x_b1:
 ; RV32I:       # %bb.0: # %entry

>From ef833f4e348177dca998fe2b4697e6595736b348 Mon Sep 17 00:00:00 2001
From: Daniel Mokeev <mokeev.gh at gmail.com>
Date: Thu, 10 Oct 2024 18:11:59 +0200
Subject: [PATCH 3/3] Fix typo, remove unused includes and add early return
 for correctness

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ee948ca71a31c9..57010ab1adb3dd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -23,14 +23,12 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
-#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
-#include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13598,6 +13596,10 @@ static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
   if (!ShAmtCLeft)
     return SDValue();
   unsigned ShiftedAmount = 8 - ShAmtCLeft->getZExtValue();
+
+  if (ShiftedAmount >= 8)
+    return SDValue();
+
   SDValue LeftShiftOperand = N0->getOperand(0);
   SDValue RightShiftOperand = N1;
 



More information about the llvm-commits mailing list