[clang-tools-extra] [RISCV][GISel] Add ISel supports for SHXADD from Zba extension (PR #67863)

Min-Yih Hsu via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 18 15:36:30 PDT 2023


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/67863

>From 08f77d6a53dadd4c136b92fcb60700fd7389eeb3 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 29 Sep 2023 15:17:43 -0700
Subject: [PATCH 1/7] [RISCV][GISel] Add ISel supports for SHXADD from Zba
 extension

This patch ports the SDISel patterns for the SHXADD instructions to
GlobalISel.
Note that `non_imm12`, a predicate that was implemented with `PatLeaf`,
is now turned into a ComplexPattern so that the patterns using it can
be shared between SDISel and GISel.
---
 .../RISCV/GISel/RISCVInstructionSelector.cpp  | 130 +++++++++++++++
 llvm/lib/Target/RISCV/RISCVGISel.td           |  10 ++
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   |   9 ++
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h     |   2 +
 llvm/lib/Target/RISCV/RISCVInstrInfoZb.td     |  51 +++---
 .../instruction-select/zba-rv32.mir           | 152 ++++++++++++++++++
 .../instruction-select/zba-rv64.mir           | 152 ++++++++++++++++++
 7 files changed, 479 insertions(+), 27 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir
 create mode 100644 llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 4f97a0d84f686f9..3a98e84546f376f 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -17,6 +17,7 @@
 #include "RISCVTargetMachine.h"
 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
 #include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/Support/Debug.h"
@@ -55,6 +56,14 @@ class RISCVInstructionSelector : public InstructionSelector {
 
   ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
 
+  ComplexRendererFns selectNonImm12(MachineOperand &Root) const;
+
+  ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const;
+  template <unsigned ShAmt>
+  ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const {
+    return selectSHXADDOp(Root, ShAmt);
+  }
+
   // Custom renderers for tablegen
   void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI,
                     int OpIdx) const;
@@ -105,6 +114,127 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
   return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
 }
 
+// This complex pattern actually serves as a perdicate that is effectively
+// `!isInt<12>(Imm)`.
+InstructionSelector::ComplexRendererFns
+RISCVInstructionSelector::selectNonImm12(MachineOperand &Root) const {
+  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  if (Root.isReg() && Root.getReg())
+    if (auto Val = getIConstantVRegValWithLookThrough(Root.getReg(), MRI)) {
+      // We do NOT want immediates that fit in 12 bits.
+      if (isInt<12>(Val->Value.getSExtValue()))
+        return std::nullopt;
+    }
+
+  return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
+}
+
+InstructionSelector::ComplexRendererFns
+RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
+                                         unsigned ShAmt) const {
+  using namespace llvm::MIPatternMatch;
+  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  if (!Root.isReg())
+    return std::nullopt;
+  Register RootReg = Root.getReg();
+
+  const unsigned XLen = STI.getXLen();
+  APInt Mask, C2;
+  Register RegY;
+  std::optional<bool> LeftShift;
+  // (and (shl y, c2), mask)
+  if (mi_match(RootReg, MRI,
+               m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
+    LeftShift = true;
+  // (and (lshr y, c2), mask)
+  else if (mi_match(RootReg, MRI,
+                    m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
+    LeftShift = false;
+
+  if (LeftShift.has_value()) {
+    if (*LeftShift)
+      Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
+    else
+      Mask &= maskTrailingOnes<uint64_t>(XLen - C2.getLimitedValue());
+
+    if (Mask.isShiftedMask()) {
+      unsigned Leading = XLen - Mask.getActiveBits();
+      unsigned Trailing = Mask.countr_zero();
+      // Given (and (shl y, c2), mask) in which mask has no leading zeros and c3
+      // trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
+      if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
+              .addImm(Trailing - C2.getLimitedValue());
+          MIB.addReg(DstReg);
+        }}};
+      }
+
+      // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and c3
+      // trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
+      if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
+              .addImm(Leading + Trailing);
+          MIB.addReg(DstReg);
+        }}};
+      }
+    }
+  }
+
+  LeftShift.reset();
+
+  // (shl (and y, mask), c2)
+  if (mi_match(RootReg, MRI,
+               m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
+                      m_ICst(C2))))
+    LeftShift = true;
+  // (lshr (and y, mask), c2)
+  else if (mi_match(RootReg, MRI,
+                    m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
+                            m_ICst(C2))))
+    LeftShift = false;
+
+  if (LeftShift.has_value())
+    if (Mask.isShiftedMask()) {
+      unsigned Leading = XLen - Mask.getActiveBits();
+      unsigned Trailing = Mask.countr_zero();
+
+      // Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
+      // c3 trailing zeros. If c2 + c3 == ShAmt, we can emit SRLIW + SHXADD.
+      bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
+                  (Trailing + C2.getLimitedValue()) == ShAmt;
+      if (!Cond)
+        // Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
+        // c3 trailing zeros. If c3 - c2 == ShAmt, we can emit SRLIW + SHXADD.
+        Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
+               (Trailing - C2.getLimitedValue()) == ShAmt;
+
+      if (Cond) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
+              .addImm(Trailing);
+          MIB.addReg(DstReg);
+        }}};
+      }
+    }
+
+  return std::nullopt;
+}
+
 // Tablegen doesn't allow us to write SRLIW/SRAIW/SLLIW patterns because the
 // immediate Operand has type XLenVT. GlobalISel wants it to be i32.
 bool RISCVInstructionSelector::earlySelectShift(
diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td
index 8059b517f26ba3c..2d6a293c2cca148 100644
--- a/llvm/lib/Target/RISCV/RISCVGISel.td
+++ b/llvm/lib/Target/RISCV/RISCVGISel.td
@@ -31,6 +31,16 @@ def ShiftMaskGI :
     GIComplexOperandMatcher<s32, "selectShiftMask">,
     GIComplexPatternEquiv<shiftMaskXLen>;
 
+def gi_non_imm12 : GIComplexOperandMatcher<s32, "selectNonImm12">,
+                   GIComplexPatternEquiv<non_imm12>;
+
+def gi_sh1add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<1>">,
+                   GIComplexPatternEquiv<sh1add_op>;
+def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
+                   GIComplexPatternEquiv<sh2add_op>;
+def gi_sh3add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<3>">,
+                   GIComplexPatternEquiv<sh3add_op>;
+
 // FIXME: Canonicalize (sub X, C) -> (add X, -C) earlier.
 def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
           (ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 70b9041852f91f8..de04f4c12e5e8e2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2443,6 +2443,15 @@ bool RISCVDAGToDAGISel::SelectAddrRegImm(SDValue Addr, SDValue &Base,
   return true;
 }
 
+bool RISCVDAGToDAGISel::selectNonImm12(SDValue N, SDValue &Opnd) {
+  auto *C = dyn_cast<ConstantSDNode>(N);
+  if (!C || !isInt<12>(C->getSExtValue())) {
+    Opnd = N;
+    return true;
+  }
+  return false;
+}
+
 bool RISCVDAGToDAGISel::selectShiftMask(SDValue N, unsigned ShiftWidth,
                                         SDValue &ShAmt) {
   ShAmt = N;
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
index c220b2d57c2e50f..d3d095a370683df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -83,6 +83,8 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
   bool trySignedBitfieldExtract(SDNode *Node);
   bool tryIndexedLoad(SDNode *Node);
 
+  bool selectNonImm12(SDValue N, SDValue &Opnd);
+
   bool selectShiftMask(SDValue N, unsigned ShiftWidth, SDValue &ShAmt);
   bool selectShiftMaskXLen(SDValue N, SDValue &ShAmt) {
     return selectShiftMask(N, Subtarget->getXLen(), ShAmt);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index a21c3d132636bea..c20c3176bb27dbc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -235,10 +235,7 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
 }]>;
 
 // Pattern to exclude simm12 immediates from matching.
-def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
-  auto *C = dyn_cast<ConstantSDNode>(N);
-  return !C || !isInt<12>(C->getSExtValue());
-}]>;
+def non_imm12 : ComplexPattern<XLenVT, 1, "selectNonImm12", [], [], 0>;
 
 def Shifted32OnesMask : PatLeaf<(imm), [{
   uint64_t Imm = N->getZExtValue();
@@ -651,19 +648,19 @@ let Predicates = [HasStdExtZbb, IsRV64] in
 def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>;
 
 let Predicates = [HasStdExtZba] in {
-def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), non_imm12:$rs2),
+def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), (non_imm12 (XLenVT GPR:$rs2))),
           (SH1ADD GPR:$rs1, GPR:$rs2)>;
-def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), non_imm12:$rs2),
+def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), (non_imm12 (XLenVT GPR:$rs2))),
           (SH2ADD GPR:$rs1, GPR:$rs2)>;
-def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), non_imm12:$rs2),
+def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), (non_imm12 (XLenVT GPR:$rs2))),
           (SH3ADD GPR:$rs1, GPR:$rs2)>;
 
 // More complex cases use a ComplexPattern.
-def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add sh1add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
           (SH1ADD sh1add_op:$rs1, GPR:$rs2)>;
-def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add sh2add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
           (SH2ADD sh2add_op:$rs1, GPR:$rs2)>;
-def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add sh3add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
           (SH3ADD sh3add_op:$rs1, GPR:$rs2)>;
 
 def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
@@ -735,48 +732,48 @@ def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)),
           (SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask),
                    Shifted32OnesMask:$mask)>;
 
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (ADD_UW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>;
 
-def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (ADD_UW GPR:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), non_imm12:$rs2)),
+def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), non_imm12:$rs2)),
+def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), non_imm12:$rs2)),
+def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
 // More complex cases use a ComplexPattern.
-def : Pat<(i64 (add sh1add_uw_op:$rs1, non_imm12:$rs2)),
+def : Pat<(i64 (add sh1add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
+def : Pat<(i64 (add sh2add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
+def : Pat<(i64 (add sh3add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>;
 
 // Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
-def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>;
 
 def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),
diff --git a/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir
new file mode 100644
index 000000000000000..f90de3ea55a1bb7
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir
@@ -0,0 +1,152 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 3
+# RUN: llc -mtriple=riscv32 -mattr='+zba' -run-pass=instruction-select -simplify-mir -verify-machineinstrs %s -o - \
+# RUN: | FileCheck %s
+
+---
+name:            sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH1ADD:%[0-9]+]]:gpr = SH1ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH1ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+    %2:gprb(s32) = G_CONSTANT i32 1
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            sh2add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh2add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+    %2:gprb(s32) = G_CONSTANT i32 2
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            sh3add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh3add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH3ADD:%[0-9]+]]:gpr = SH3ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH3ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+    %2:gprb(s32) = G_CONSTANT i32 3
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            no_sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: no_sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[SLLI:%[0-9]+]]:gpr = SLLI [[COPY]], 1
+    ; CHECK-NEXT: [[ADDI:%[0-9]+]]:gpr = ADDI [[SLLI]], 37
+    ; CHECK-NEXT: $x10 = COPY [[ADDI]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = G_CONSTANT i32 37
+    %2:gprb(s32) = G_CONSTANT i32 1
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            shXadd_complex_shl_and
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_complex_shl_and
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SRLI:%[0-9]+]]:gpr = SRLI [[COPY]], 1
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[SRLI]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+
+    %2:gprb(s32) = G_CONSTANT i32 1
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_CONSTANT i32 4294967292
+    %5:gprb(s32) = G_AND %3, %4
+
+    %6:gprb(s32) = G_ADD %5, %1
+    $x10 = COPY %6(s32)
+...
+---
+name:            shXadd_complex_lshr_and
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_complex_lshr_and
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SRLI:%[0-9]+]]:gpr = SRLI [[COPY]], 29
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[SRLI]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+
+    %2:gprb(s32) = G_CONSTANT i32 27
+    %3:gprb(s32) = G_LSHR %0, %2
+    %4:gprb(s32) = G_CONSTANT i32 60
+    %5:gprb(s32) = G_AND %3, %4
+
+    %6:gprb(s32) = G_ADD %5, %1
+    $x10 = COPY %6(s32)
+...
diff --git a/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
new file mode 100644
index 000000000000000..092a3305b3453d2
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
@@ -0,0 +1,152 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 3
+# RUN: llc -mtriple=riscv64 -mattr='+zba' -run-pass=instruction-select -simplify-mir -verify-machineinstrs %s -o - \
+# RUN: | FileCheck %s
+
+---
+name:            sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH1ADD:%[0-9]+]]:gpr = SH1ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH1ADD]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = COPY $x11
+    %2:gprb(s64) = G_CONSTANT i64 1
+    %3:gprb(s64) = G_SHL %0, %2
+    %4:gprb(s64) = G_ADD %3, %1
+    $x10 = COPY %4(s64)
+...
+---
+name:            sh2add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh2add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = COPY $x11
+    %2:gprb(s64) = G_CONSTANT i64 2
+    %3:gprb(s64) = G_SHL %0, %2
+    %4:gprb(s64) = G_ADD %3, %1
+    $x10 = COPY %4(s64)
+...
+---
+name:            sh3add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh3add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH3ADD:%[0-9]+]]:gpr = SH3ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH3ADD]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = COPY $x11
+    %2:gprb(s64) = G_CONSTANT i64 3
+    %3:gprb(s64) = G_SHL %0, %2
+    %4:gprb(s64) = G_ADD %3, %1
+    $x10 = COPY %4(s64)
+...
+---
+name:            no_sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: no_sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[SLLI:%[0-9]+]]:gpr = SLLI [[COPY]], 1
+    ; CHECK-NEXT: [[ADDI:%[0-9]+]]:gpr = ADDI [[SLLI]], 37
+    ; CHECK-NEXT: $x10 = COPY [[ADDI]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = G_CONSTANT i64 37
+    %2:gprb(s64) = G_CONSTANT i64 1
+    %3:gprb(s64) = G_SHL %0, %2
+    %4:gprb(s64) = G_ADD %3, %1
+    $x10 = COPY %4(s64)
+...
+---
+name:            shXadd_complex_and_shl
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_complex_and_shl
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SRLIW:%[0-9]+]]:gpr = SRLIW [[COPY]], 1
+    ; CHECK-NEXT: [[SH3ADD:%[0-9]+]]:gpr = SH3ADD [[SRLIW]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH3ADD]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = COPY $x11
+
+    %2:gprb(s64) = G_CONSTANT i64 4294967294
+    %3:gprb(s64) = G_AND %0, %2
+    %4:gprb(s64) = G_CONSTANT i64 2
+    %5:gprb(s64) = G_SHL %3, %4
+
+    %6:gprb(s64) = G_ADD %5, %1
+    $x10 = COPY %6(s64)
+...
+---
+name:            shXadd_complex_and_lshr
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_complex_and_lshr
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SRLIW:%[0-9]+]]:gpr = SRLIW [[COPY]], 2
+    ; CHECK-NEXT: [[SH1ADD:%[0-9]+]]:gpr = SH1ADD [[SRLIW]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH1ADD]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = COPY $x11
+
+    %2:gprb(s64) = G_CONSTANT i64 4294967292
+    %3:gprb(s64) = G_AND %0, %2
+    %4:gprb(s64) = G_CONSTANT i64 1
+    %5:gprb(s64) = G_LSHR %3, %4
+
+    %6:gprb(s64) = G_ADD %5, %1
+    $x10 = COPY %6(s64)
+...

>From 4d81ad5ee98aa284487b59ea1abef5090a746b6c Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 29 Sep 2023 15:54:05 -0700
Subject: [PATCH 2/7] fixup! [RISCV][GISel] Add ISel supports for SHXADD from
 Zba extension

---
 llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 3a98e84546f376f..3be97b016f47fea 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -164,8 +164,8 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
     if (Mask.isShiftedMask()) {
       unsigned Leading = XLen - Mask.getActiveBits();
       unsigned Trailing = Mask.countr_zero();
-      // Given (and (shl y, c2), mask) in which mask has no leading zeros and c3
-      // trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
+      // Given (and (shl y, c2), mask) in which mask has no leading zeros and
+      // c3 trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
       if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
         Register DstReg =
             MRI.createGenericVirtualRegister(MRI.getType(RootReg));

>From 2d4dce18884979959ca9cdf1d99a3134e6efe6ac Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Mon, 2 Oct 2023 11:10:16 -0700
Subject: [PATCH 3/7] (Staging) Use GISelPredicateCode in all SHXADD patterns

But since llvm-tblgen has a bug that makes it crash whenever a
ComplexPattern fails to be imported together with
`PredicateCodeUsesOperands` + `GISelPredicateCode`, we preserve the
original `non_imm12` (PatLeaf) and leave all `SHXADD_UW` patterns
untouched.
---
 llvm/lib/Target/RISCV/RISCVGISel.td       |  3 -
 llvm/lib/Target/RISCV/RISCVInstrInfoZb.td | 91 ++++++++++++++---------
 2 files changed, 56 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td
index 2d6a293c2cca148..e0bc25c570cd209 100644
--- a/llvm/lib/Target/RISCV/RISCVGISel.td
+++ b/llvm/lib/Target/RISCV/RISCVGISel.td
@@ -31,9 +31,6 @@ def ShiftMaskGI :
     GIComplexOperandMatcher<s32, "selectShiftMask">,
     GIComplexPatternEquiv<shiftMaskXLen>;
 
-def gi_non_imm12 : GIComplexOperandMatcher<s32, "selectNonImm12">,
-                   GIComplexPatternEquiv<non_imm12>;
-
 def gi_sh1add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<1>">,
                    GIComplexPatternEquiv<sh1add_op>;
 def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index c20c3176bb27dbc..6a1e8531c1650b2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -235,7 +235,33 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
 }]>;
 
 // Pattern to exclude simm12 immediates from matching.
-def non_imm12 : ComplexPattern<XLenVT, 1, "selectNonImm12", [], [], 0>;
+def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
+  auto *C = dyn_cast<ConstantSDNode>(N);
+  return !C || !isInt<12>(C->getSExtValue());
+}]>;
+
+class binop_with_non_imm12<SDPatternOperator binop> : PatFrag<(ops node:$x, node:$y), (binop node:$x, node:$y), [{
+  auto *C = dyn_cast<ConstantSDNode>(Operands[1]);
+  return !C || !isInt<12>(C->getSExtValue());
+}]> {
+  let PredicateCodeUsesOperands = 1;
+  let GISelPredicateCode = [{
+    const MachineOperand &ImmOp = *Operands[1];
+    const MachineFunction &MF = *MI.getParent()->getParent();
+    const MachineRegisterInfo &MRI = MF.getRegInfo();
+
+    if (ImmOp.isReg() && ImmOp.getReg())
+      if (auto Val = getIConstantVRegValWithLookThrough(ImmOp.getReg(), MRI)) {
+        // We do NOT want immediates that fit in 12 bits.
+        return !isInt<12>(Val->Value.getSExtValue());
+      }
+
+    return true;
+  }];
+}
+def add_non_imm12 : binop_with_non_imm12<add>;
+def or_is_add_non_imm12 : binop_with_non_imm12<or_is_add>;
+
 
 def Shifted32OnesMask : PatLeaf<(imm), [{
   uint64_t Imm = N->getZExtValue();
@@ -648,20 +674,17 @@ let Predicates = [HasStdExtZbb, IsRV64] in
 def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>;
 
 let Predicates = [HasStdExtZba] in {
-def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), (non_imm12 (XLenVT GPR:$rs2))),
-          (SH1ADD GPR:$rs1, GPR:$rs2)>;
-def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), (non_imm12 (XLenVT GPR:$rs2))),
-          (SH2ADD GPR:$rs1, GPR:$rs2)>;
-def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), (non_imm12 (XLenVT GPR:$rs2))),
-          (SH3ADD GPR:$rs1, GPR:$rs2)>;
 
-// More complex cases use a ComplexPattern.
-def : Pat<(add sh1add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
-          (SH1ADD sh1add_op:$rs1, GPR:$rs2)>;
-def : Pat<(add sh2add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
-          (SH2ADD sh2add_op:$rs1, GPR:$rs2)>;
-def : Pat<(add sh3add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
-          (SH3ADD sh3add_op:$rs1, GPR:$rs2)>;
+foreach i = {1,2,3} in {
+  defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
+  def : Pat<(XLenVT (add_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
+
+  defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
+  // More complex cases use a ComplexPattern.
+  def : Pat<(XLenVT (add_non_imm12 pat:$rs1, GPR:$rs2)),
+            (shxadd pat:$rs1, GPR:$rs2)>;
+}
 
 def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
           (SH1ADD (SH1ADD GPR:$rs1, GPR:$rs1), GPR:$rs2)>;
@@ -731,49 +754,47 @@ def : Pat<(i64 (shl (and GPR:$rs1, 0xFFFFFFFF), uimm5:$shamt)),
 def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)),
           (SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask),
                    Shifted32OnesMask:$mask)>;
-
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFF), GPR:$rs2)),
           (ADD_UW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>;
 
-def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (or_is_add_non_imm12 (and GPR:$rs1, 0xFFFFFFFF), GPR:$rs2)),
           (ADD_UW GPR:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), (non_imm12 (XLenVT GPR:$rs2)))),
-          (SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), (non_imm12 (XLenVT GPR:$rs2)))),
-          (SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), (non_imm12 (XLenVT GPR:$rs2)))),
-          (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
+foreach i = {1,2,3} in {
+  defvar shxadd_uw = !cast<Instruction>("SH"#i#"ADD_UW");
+  def : Pat<(i64 (add_non_imm12 (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 i)), (XLenVT GPR:$rs2))),
+            (shxadd_uw GPR:$rs1, GPR:$rs2)>;
+}
 
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (XLenVT GPR:$rs2))),
           (SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (XLenVT GPR:$rs2))),
           (SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (XLenVT GPR:$rs2))),
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
 // More complex cases use a ComplexPattern.
-def : Pat<(i64 (add sh1add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add sh1add_uw_op:$rs1, non_imm12:$rs2)),
           (SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh2add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
           (SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh3add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
           (SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFE), (XLenVT GPR:$rs2))),
           (SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFC), (XLenVT GPR:$rs2))),
           (SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFF8), (XLenVT GPR:$rs2))),
           (SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>;
 
 // Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
-def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x1FFFFFFFE), (XLenVT GPR:$rs2))),
           (SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x3FFFFFFFC), (XLenVT GPR:$rs2))),
           (SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
+def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8), (XLenVT GPR:$rs2))),
           (SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>;
 
 def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),

>From 9de0c2f758f379c0c1a620223364433e34980b1d Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Tue, 3 Oct 2023 15:02:26 -0700
Subject: [PATCH 4/7] fixup! (Staging) Use GISelPredicateCode in all SHXADD
 patterns

---
 .../RISCV/GISel/RISCVInstructionSelector.cpp  | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 3be97b016f47fea..96498d3cbab0190 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -56,8 +56,6 @@ class RISCVInstructionSelector : public InstructionSelector {
 
   ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
 
-  ComplexRendererFns selectNonImm12(MachineOperand &Root) const;
-
   ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const;
   template <unsigned ShAmt>
   ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const {
@@ -114,23 +112,6 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
   return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
 }
 
-// This complex pattern actually serves as a perdicate that is effectively
-// `!isInt<12>(Imm)`.
-InstructionSelector::ComplexRendererFns
-RISCVInstructionSelector::selectNonImm12(MachineOperand &Root) const {
-  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-
-  if (Root.isReg() && Root.getReg())
-    if (auto Val = getIConstantVRegValWithLookThrough(Root.getReg(), MRI)) {
-      // We do NOT want immediates that fit in 12 bits.
-      if (isInt<12>(Val->Value.getSExtValue()))
-        return std::nullopt;
-    }
-
-  return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
-}
-
 InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
                                          unsigned ShAmt) const {

>From 0b2e658dcd74974518a3ad031185895dbca768e6 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Tue, 3 Oct 2023 15:13:35 -0700
Subject: [PATCH 5/7] fixup! (Staging) Use GISelPredicateCode in all SHXADD
 patterns

---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp |  9 ---------
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h   |  2 --
 llvm/lib/Target/RISCV/RISCVInstrInfoZb.td   | 11 ++++++++---
 3 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index de04f4c12e5e8e2..70b9041852f91f8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2443,15 +2443,6 @@ bool RISCVDAGToDAGISel::SelectAddrRegImm(SDValue Addr, SDValue &Base,
   return true;
 }
 
-bool RISCVDAGToDAGISel::selectNonImm12(SDValue N, SDValue &Opnd) {
-  auto *C = dyn_cast<ConstantSDNode>(N);
-  if (!C || !isInt<12>(C->getSExtValue())) {
-    Opnd = N;
-    return true;
-  }
-  return false;
-}
-
 bool RISCVDAGToDAGISel::selectShiftMask(SDValue N, unsigned ShiftWidth,
                                         SDValue &ShAmt) {
   ShAmt = N;
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
index d3d095a370683df..c220b2d57c2e50f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -83,8 +83,6 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
   bool trySignedBitfieldExtract(SDNode *Node);
   bool tryIndexedLoad(SDNode *Node);
 
-  bool selectNonImm12(SDValue N, SDValue &Opnd);
-
   bool selectShiftMask(SDValue N, unsigned ShiftWidth, SDValue &ShAmt);
   bool selectShiftMaskXLen(SDValue N, SDValue &ShAmt) {
     return selectShiftMask(N, Subtarget->getXLen(), ShAmt);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index 6a1e8531c1650b2..f8b4bc4945eb0a4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -235,12 +235,18 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
 }]>;
 
 // Pattern to exclude simm12 immediates from matching.
+// Note: this will be removed once the GISel complex patterns for
+// SHXADD_UW have landed.
 def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
   auto *C = dyn_cast<ConstantSDNode>(N);
   return !C || !isInt<12>(C->getSExtValue());
 }]>;
 
-class binop_with_non_imm12<SDPatternOperator binop> : PatFrag<(ops node:$x, node:$y), (binop node:$x, node:$y), [{
+// GISel currently doesn't support PatFrag for leaf nodes, so `non_imm12`
+// cannot be directly supported in GISel. To reuse patterns between the two
+// ISels, we instead create PatFrag on operators that use `non_imm12`.
+class binop_with_non_imm12<SDPatternOperator binop>
+  : PatFrag<(ops node:$x, node:$y), (binop node:$x, node:$y), [{
   auto *C = dyn_cast<ConstantSDNode>(Operands[1]);
   return !C || !isInt<12>(C->getSExtValue());
 }]> {
@@ -259,10 +265,9 @@ class binop_with_non_imm12<SDPatternOperator binop> : PatFrag<(ops node:$x, node
     return true;
   }];
 }
-def add_non_imm12 : binop_with_non_imm12<add>;
+def add_non_imm12       : binop_with_non_imm12<add>;
 def or_is_add_non_imm12 : binop_with_non_imm12<or_is_add>;
 
-
 def Shifted32OnesMask : PatLeaf<(imm), [{
   uint64_t Imm = N->getZExtValue();
   if (!isShiftedMask_64(Imm))

>From 5484c7e950b8afcfdce9ab4841cd765751c8e4bd Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Wed, 18 Oct 2023 13:40:11 -0700
Subject: [PATCH 6/7] fixup! (Staging) Use GISelPredicateCode in all SHXADD
 patterns

---
 .../RISCV/GISel/RISCVInstructionSelector.cpp  | 49 +++++++++----------
 1 file changed, 24 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 96498d3cbab0190..0cec3a2f215e6cc 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -186,32 +186,31 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
                             m_ICst(C2))))
     LeftShift = false;
 
-  if (LeftShift.has_value())
-    if (Mask.isShiftedMask()) {
-      unsigned Leading = XLen - Mask.getActiveBits();
-      unsigned Trailing = Mask.countr_zero();
-
-      // Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
-      // c3 trailing zeros. If c1 + c3 == ShAmt, we can emit SRLIW + SHXADD.
-      bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
-                  (Trailing + C2.getLimitedValue()) == ShAmt;
-      if (!Cond)
-        // Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
-        // c3 trailing zeros. If c3 - c1 == ShAmt, we can emit SRLIW + SHXADD.
-        Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
-               (Trailing - C2.getLimitedValue()) == ShAmt;
-
-      if (Cond) {
-        Register DstReg =
-            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
-        return {{[=](MachineInstrBuilder &MIB) {
-          MachineIRBuilder(*MIB.getInstr())
-              .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
-              .addImm(Trailing);
-          MIB.addReg(DstReg);
-        }}};
-      }
+  if (LeftShift.has_value() && Mask.isShiftedMask()) {
+    unsigned Leading = XLen - Mask.getActiveBits();
+    unsigned Trailing = Mask.countr_zero();
+
+    // Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
+    // c3 trailing zeros. If c2 + c3 == ShAmt, we can emit SRLIW + SHXADD.
+    bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
+                (Trailing + C2.getLimitedValue()) == ShAmt;
+    if (!Cond)
+      // Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
+      // c3 trailing zeros. If c3 - c2 == ShAmt, we can emit SRLIW + SHXADD.
+      Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
+             (Trailing - C2.getLimitedValue()) == ShAmt;
+
+    if (Cond) {
+      Register DstReg =
+          MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+      return {{[=](MachineInstrBuilder &MIB) {
+        MachineIRBuilder(*MIB.getInstr())
+            .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
+            .addImm(Trailing);
+        MIB.addReg(DstReg);
+      }}};
     }
+  }
 
   return std::nullopt;
 }

>From d8edabad24e3ef0418c5e31190f243037669238f Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Wed, 18 Oct 2023 15:35:36 -0700
Subject: [PATCH 7/7] fixup! (Staging) Use GISelPredicateCode in all SHXADD
 patterns

---
 llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 1ecca6745596f0b..e5354905eaffcc9 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -175,8 +175,8 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
         }}};
       }
 
-      // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and c3
-      // trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
+      // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and
+      // c3 trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
       if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
         Register DstReg =
             MRI.createGenericVirtualRegister(MRI.getType(RootReg));
@@ -218,8 +218,7 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
              (Trailing - C2.getLimitedValue()) == ShAmt;
 
     if (Cond) {
-      Register DstReg =
-          MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+      Register DstReg = MRI.createGenericVirtualRegister(MRI.getType(RootReg));
       return {{[=](MachineInstrBuilder &MIB) {
         MachineIRBuilder(*MIB.getInstr())
             .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})



More information about the cfe-commits mailing list