[llvm] [RISCV] Add DAG combine to turn (sub (shl X, 8), X) into orc.b (PR #96680)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 12:04:17 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes:

This combine applies only when the sole bits of X that can be non-zero are bits 0, 8, 16, 24, etc., i.e. the low bit of each byte.

(sub (shl X, 8), X) is the pattern that (mul X, 255) is decomposed into, and that decomposition happens early, before the RISC-V DAG combines run.

This patch does not support types larger than XLen, so i64 on RV32 fails to generate two orc.b instructions. It might have worked if the mul had not been decomposed before it was expanded.
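
As an illustration (not part of the patch), below is a small standalone C++ sketch of the arithmetic the combine relies on: when only the low bit of each byte of X can be set, (X << 8) - X, i.e. X * 255, fills exactly the bytes that orc.b would. `OrcBRef` is a hypothetical reference model of the orc.b semantics written just for this example, not an LLVM or RISC-V API.

```cpp
// Minimal standalone sketch: checks that for any X where only the low bit of
// each byte can be set, the decomposed multiply (X << 8) - X produces the same
// result as a reference model of orc.b.
#include <cassert>
#include <cstdint>

// Reference model of orc.b: each byte of the result is 0xFF if the
// corresponding byte of the input is non-zero, and 0x00 otherwise.
static uint32_t OrcBRef(uint32_t X) {
  uint32_t R = 0;
  for (int I = 0; I < 32; I += 8)
    if ((X >> I) & 0xFF)
      R |= UINT32_C(0xFF) << I;
  return R;
}

int main() {
  // Enumerate all i32 values whose only possibly-set bits are 0, 8, 16, 24.
  for (uint32_t Bits = 0; Bits < 16; ++Bits) {
    uint32_t X = ((Bits & 1) << 0) | ((Bits & 2) << 7) |
                 ((Bits & 4) << 14) | ((Bits & 8) << 21);
    assert(X * 255u == OrcBRef(X));     // the original (mul X, 255)
    assert((X << 8) - X == OrcBRef(X)); // the decomposed form matched here
  }
  return 0;
}
```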

Partial fix for #96595.

---
Full diff: https://github.com/llvm/llvm-project/pull/96680.diff


3 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+36-2) 
- (modified) llvm/test/CodeGen/RISCV/rv32zbb.ll (+61) 
- (modified) llvm/test/CodeGen/RISCV/rv64zbb.ll (+66) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7c7f167821234d..9e988783e2eb6c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12502,12 +12502,15 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     }
     break;
   }
-  case RISCVISD::BREV8: {
+  case RISCVISD::BREV8:
+  case RISCVISD::ORC_B: {
     MVT VT = N->getSimpleValueType(0);
     MVT XLenVT = Subtarget.getXLenVT();
     assert((VT == MVT::i16 || (VT == MVT::i32 && Subtarget.is64Bit())) &&
            "Unexpected custom legalisation");
-    assert(Subtarget.hasStdExtZbkb() && "Unexpected extension");
+    assert(((N->getOpcode() == RISCVISD::BREV8 && Subtarget.hasStdExtZbkb()) ||
+            (N->getOpcode() == RISCVISD::ORC_B && Subtarget.hasStdExtZbb())) &&
+           "Unexpected extension");
     SDValue NewOp = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, N->getOperand(0));
     SDValue NewRes = DAG.getNode(N->getOpcode(), DL, XLenVT, NewOp);
     // ReplaceNodeResults requires we maintain the same type for the return
@@ -13345,6 +13348,35 @@ static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
 }
 
+// Looks for (sub (shl X, 8), X) where only bits 8, 16, 24, 32, etc. of X are
+// non-zero. Replace with orc.b.
+static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
+                                     const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZbb())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  if (VT != Subtarget.getXLenVT() && VT != MVT::i32 && VT != MVT::i16)
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  if (N0.getOpcode() != ISD::SHL || N0.getOperand(0) != N1 || !N0.hasOneUse())
+    return SDValue();
+
+  auto *ShAmtC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
+  if (!ShAmtC || ShAmtC->getZExtValue() != 8)
+    return SDValue();
+
+  APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0xfe));
+  if (!DAG.MaskedValueIsZero(N1, Mask))
+    return SDValue();
+
+  return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, N1);
+}
+
 static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   if (SDValue V = combineSubOfBoolean(N, DAG))
@@ -13367,6 +13399,8 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
 
   if (SDValue V = combineBinOpOfZExt(N, DAG))
     return V;
+  if (SDValue V = combineSubShiftToOrcB(N, DAG, Subtarget))
+    return V;
 
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
diff --git a/llvm/test/CodeGen/RISCV/rv32zbb.ll b/llvm/test/CodeGen/RISCV/rv32zbb.ll
index f25aa0de89da88..cb9fc6c16333e0 100644
--- a/llvm/test/CodeGen/RISCV/rv32zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv32zbb.ll
@@ -1356,3 +1356,64 @@ define i64 @bswap_i64(i64 %a) {
   %1 = call i64 @llvm.bswap.i64(i64 %a)
   ret i64 %1
 }
+
+define i16 @orc_b_i16(i16 %a) {
+; RV32I-LABEL: orc_b_i16:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    andi a0, a0, 257
+; RV32I-NEXT:    slli a1, a0, 8
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i16:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    andi a0, a0, 257
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+  %1 = and i16 %a, 257
+  %2 = mul nuw i16 %1, 255
+  ret i16 %2
+}
+
+define i32 @orc_b_i32(i32 %a) {
+; RV32I-LABEL: orc_b_i32:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    lui a1, 4112
+; RV32I-NEXT:    addi a1, a1, 257
+; RV32I-NEXT:    and a0, a0, a1
+; RV32I-NEXT:    slli a1, a0, 8
+; RV32I-NEXT:    sub a0, a1, a0
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: orc_b_i32:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    lui a1, 4112
+; RV32ZBB-NEXT:    addi a1, a1, 257
+; RV32ZBB-NEXT:    and a0, a0, a1
+; RV32ZBB-NEXT:    orc.b a0, a0
+; RV32ZBB-NEXT:    ret
+  %1 = and i32 %a, 16843009
+  %2 = mul nuw i32 %1, 255
+  ret i32 %2
+}
+
+define i64 @orc_b_i64(i64 %a) {
+; CHECK-LABEL: orc_b_i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a2, 4112
+; CHECK-NEXT:    addi a2, a2, 257
+; CHECK-NEXT:    and a1, a1, a2
+; CHECK-NEXT:    and a0, a0, a2
+; CHECK-NEXT:    slli a2, a0, 8
+; CHECK-NEXT:    sltu a3, a2, a0
+; CHECK-NEXT:    srli a4, a0, 24
+; CHECK-NEXT:    slli a5, a1, 8
+; CHECK-NEXT:    or a4, a5, a4
+; CHECK-NEXT:    sub a1, a4, a1
+; CHECK-NEXT:    sub a1, a1, a3
+; CHECK-NEXT:    sub a0, a2, a0
+; CHECK-NEXT:    ret
+  %1 = and i64 %a, 72340172838076673
+  %2 = mul nuw i64 %1, 255
+  ret i64 %2
+}
diff --git a/llvm/test/CodeGen/RISCV/rv64zbb.ll b/llvm/test/CodeGen/RISCV/rv64zbb.ll
index 4d5ef5db86057b..6c354cc1b446b2 100644
--- a/llvm/test/CodeGen/RISCV/rv64zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zbb.ll
@@ -1494,3 +1494,69 @@ define i64 @bswap_i64(i64 %a) {
   %1 = call i64 @llvm.bswap.i64(i64 %a)
   ret i64 %1
 }
+
+define i16 @orc_b_i16(i16 %a) {
+; RV64I-LABEL: orc_b_i16:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    andi a0, a0, 257
+; RV64I-NEXT:    slli a1, a0, 8
+; RV64I-NEXT:    sub a0, a1, a0
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: orc_b_i16:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    andi a0, a0, 257
+; RV64ZBB-NEXT:    orc.b a0, a0
+; RV64ZBB-NEXT:    ret
+  %1 = and i16 %a, 257
+  %2 = mul nuw i16 %1, 255
+  ret i16 %2
+}
+
+define i32 @orc_b_i32(i32 %a) {
+; RV64I-LABEL: orc_b_i32:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    lui a1, 4112
+; RV64I-NEXT:    addi a1, a1, 257
+; RV64I-NEXT:    and a0, a0, a1
+; RV64I-NEXT:    slli a1, a0, 8
+; RV64I-NEXT:    subw a0, a1, a0
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: orc_b_i32:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    lui a1, 4112
+; RV64ZBB-NEXT:    addiw a1, a1, 257
+; RV64ZBB-NEXT:    and a0, a0, a1
+; RV64ZBB-NEXT:    orc.b a0, a0
+; RV64ZBB-NEXT:    ret
+  %1 = and i32 %a, 16843009
+  %2 = mul nuw i32 %1, 255
+  ret i32 %2
+}
+
+define i64 @orc_b_i64(i64 %a) {
+; RV64I-LABEL: orc_b_i64:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    lui a1, 4112
+; RV64I-NEXT:    addiw a1, a1, 257
+; RV64I-NEXT:    slli a2, a1, 32
+; RV64I-NEXT:    add a1, a1, a2
+; RV64I-NEXT:    and a0, a0, a1
+; RV64I-NEXT:    slli a1, a0, 8
+; RV64I-NEXT:    sub a0, a1, a0
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: orc_b_i64:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    lui a1, 4112
+; RV64ZBB-NEXT:    addiw a1, a1, 257
+; RV64ZBB-NEXT:    slli a2, a1, 32
+; RV64ZBB-NEXT:    add a1, a1, a2
+; RV64ZBB-NEXT:    and a0, a0, a1
+; RV64ZBB-NEXT:    orc.b a0, a0
+; RV64ZBB-NEXT:    ret
+  %1 = and i64 %a, 72340172838076673
+  %2 = mul nuw i64 %1, 255
+  ret i64 %2
+}

``````````



https://github.com/llvm/llvm-project/pull/96680


More information about the llvm-commits mailing list