[llvm] [RISCV] Support 3-argument associative add for transformAddShlImm (PR #86883)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 27 15:22:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

This transform is looking for the shNadd idiom for zba, but that can be obscured if there's another value being added to the result.  The choice to restrict to one level of association is tactical - we could of course do more, but there's the usual compile time tradeoff, and this covers the motivating example.

This is a solution to a reduced test case originally flagged in the description of https://github.com/llvm/llvm-project/pull/85734.

---
Full diff: https://github.com/llvm/llvm-project/pull/86883.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+43-11) 
- (modified) llvm/test/CodeGen/RISCV/rv64zba.ll (+4-6) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 564fda674317f4..14ace1d30a4112 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12733,20 +12733,13 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
 
 // Optimize (add (shl x, c0), (shl y, c1)) ->
 //          (SLLI (SH*ADD x, y), c0), if c1-c0 equals to [1|2|3].
-static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
+static SDValue transformAddShlImm(SDValue N0, SDValue N1, SDLoc DL,
+                                  SelectionDAG &DAG,
                                   const RISCVSubtarget &Subtarget) {
-  // Perform this optimization only in the zba extension.
-  if (!Subtarget.hasStdExtZba())
-    return SDValue();
 
-  // Skip for vector types and larger types.
-  EVT VT = N->getValueType(0);
-  if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
-    return SDValue();
+  EVT VT = N0.getValueType();
 
   // The two operand nodes must be SHL and have no other use.
-  SDValue N0 = N->getOperand(0);
-  SDValue N1 = N->getOperand(1);
   if (N0->getOpcode() != ISD::SHL || N1->getOpcode() != ISD::SHL ||
       !N0->hasOneUse() || !N1->hasOneUse())
     return SDValue();
@@ -12768,7 +12761,6 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   // Build nodes.
-  SDLoc DL(N);
   SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0);
   SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0);
   SDValue NA0 =
@@ -12777,6 +12769,46 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
 }
 
+// Generalized form of above which looks through one level of add
+// reassociation for opportunities.
+static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
+                                  const RISCVSubtarget &Subtarget) {
+  // Perform this optimization only in the zba extension.
+  if (!Subtarget.hasStdExtZba())
+    return SDValue();
+
+  // Skip for vector types and larger types.
+  EVT VT = N->getValueType(0);
+  if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
+    return SDValue();
+
+  // We're looking for two SHL nodes in the add tree with all nodes
+  // involved having no other use.
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  if (N0->getOpcode() != ISD::SHL)
+    std::swap(N0, N1);
+
+  if (SDValue Res = transformAddShlImm(N0, N1, SDLoc(N), DAG, Subtarget))
+    return Res;
+
+  if (N0->getOpcode() != ISD::SHL || N1->getOpcode() != ISD::ADD ||
+      !N1->hasOneUse())
+    return SDValue();
+
+  // Allow reassociation for a 3-argument add
+  SDLoc DL(N1);
+  SDValue A = N1->getOperand(0);
+  SDValue B = N1->getOperand(1);
+  if (SDValue Res = transformAddShlImm(N0, A, SDLoc(N), DAG, Subtarget))
+    return DAG.getNode(ISD::ADD, DL, VT, Res, B);
+
+  if (SDValue Res = transformAddShlImm(N0, B, SDLoc(N), DAG, Subtarget))
+    return DAG.getNode(ISD::ADD, DL, VT, Res, A);
+
+  return SDValue();
+}
+
 // Combine a constant select operand into its use:
 //
 // (and (select cond, -1, c), x)
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index d9d83633a8537f..c09cf1b6f48440 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1315,9 +1315,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
 ;
 ; RV64ZBA-LABEL: sh6_sh3_add2:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a1, a0
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ret
 entry:
   %shl = shl i64 %z, 3
@@ -1360,9 +1359,8 @@ define i64 @sh6_sh3_add4(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
 ;
 ; RV64ZBA-LABEL: sh6_sh3_add4:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
-; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ret
 entry:
   %shl = shl i64 %z, 3

``````````

</details>


https://github.com/llvm/llvm-project/pull/86883


More information about the llvm-commits mailing list