[llvm] c5c2de2 - [RISCV][ISel] Fold extensions when all the users can consume them

Quentin Colombet via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 5 13:50:17 PDT 2022


Author: Quentin Colombet
Date: 2022-10-05T20:49:21Z
New Revision: c5c2de287e5ff1803a10d94b0ab17b579442726d

URL: https://github.com/llvm/llvm-project/commit/c5c2de287e5ff1803a10d94b0ab17b579442726d
DIFF: https://github.com/llvm/llvm-project/commit/c5c2de287e5ff1803a10d94b0ab17b579442726d.diff

LOG: [RISCV][ISel] Fold extensions when all the users can consume them

This patch allows the combines that fold extensions into binary operations
to fire even when the extension has more than one use.
The approach here is pretty conservative: if all the users of an
extension can fold the extension, then the folding is done, otherwise we
don't fold.
This is the first step towards avoiding the one-use limitation.

As a result, we make a decision to fold/don't fold for a web of
instructions. An instruction is part of the web of instructions as soon
as it consumes an extension that needs to be folded for all its users.

Because of how SDISel works a web of instructions can be visited over
and over. More precisely, if the folding happens, it happens for the
whole web and that's the end of it, but if the folding fails, the whole
web may be revisited when another member of the web is visited.

To avoid a compile time explosion in pathological cases, we bail out
earlier for webs that are bigger than a given threshold (arbitrarily set
at 18 for now.) This size can be changed using
`--riscv-lower-ext-max-web-size=<maxWebSize>`.

At the current time, I don't see a better scheme for that, assuming we
want to stick with doing that in SDISel.

Differential Revision: https://reviews.llvm.org/D133739

Added: 
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f4f74eb115947..4a8bc3177c85e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -46,6 +46,12 @@ using namespace llvm;
 
 STATISTIC(NumTailCalls, "Number of tail calls");
 
+static cl::opt<unsigned> ExtensionMaxWebSize(
+    DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
+    cl::desc("Give the maximum size (in number of nodes) of the web of "
+             "instructions that we will consider for VW expansion"),
+    cl::init(18));
+
 static cl::opt<bool>
     AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
                      cl::desc("Allow the formation of VW_W operations (e.g., "
@@ -8547,9 +8553,9 @@ struct CombineResult {
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
-  const NodeExtensionHelper &LHS;
+  NodeExtensionHelper LHS;
   /// RHS of the TargetOpcode.
-  const NodeExtensionHelper &RHS;
+  NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
@@ -8728,31 +8734,83 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
   assert(NodeExtensionHelper::isSupportedRoot(N) &&
          "Shouldn't have called this method");
+  SmallVector<SDNode *> Worklist;
+  SmallSet<SDNode *, 8> Inserted;
+  Worklist.push_back(N);
+  Inserted.insert(N);
+  SmallVector<CombineResult> CombinesToApply;
+
+  while (!Worklist.empty()) {
+    SDNode *Root = Worklist.pop_back_val();
+    if (!NodeExtensionHelper::isSupportedRoot(Root))
+      return SDValue();
 
-  NodeExtensionHelper LHS(N, 0, DAG);
-  NodeExtensionHelper RHS(N, 1, DAG);
-
-  if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse())
-    return SDValue();
-
-  if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse())
-    return SDValue();
+    NodeExtensionHelper LHS(N, 0, DAG);
+    NodeExtensionHelper RHS(N, 1, DAG);
+    auto AppendUsersIfNeeded = [&Worklist,
+                                &Inserted](const NodeExtensionHelper &Op) {
+      if (Op.needToPromoteOtherUsers()) {
+        for (SDNode *TheUse : Op.OrigOperand->uses()) {
+          if (Inserted.insert(TheUse).second)
+            Worklist.push_back(TheUse);
+        }
+      }
+    };
+    AppendUsersIfNeeded(LHS);
+    AppendUsersIfNeeded(RHS);
 
-  SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
-      NodeExtensionHelper::getSupportedFoldings(N);
+    // Control the compile time by limiting the number of nodes we look at in
+    // total.
+    if (Inserted.size() > ExtensionMaxWebSize)
+      return SDValue();
 
-  assert(!FoldingStrategies.empty() && "Nothing to be folded");
-  for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N);
-       ++Attempt) {
-    for (NodeExtensionHelper::CombineToTry FoldingStrategy :
-         FoldingStrategies) {
-      Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
-      if (Res)
-        return Res->materialize(DAG);
+    SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
+        NodeExtensionHelper::getSupportedFoldings(N);
+
+    assert(!FoldingStrategies.empty() && "Nothing to be folded");
+    bool Matched = false;
+    for (int Attempt = 0;
+         (Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
+         ++Attempt) {
+
+      for (NodeExtensionHelper::CombineToTry FoldingStrategy :
+           FoldingStrategies) {
+        Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        if (Res) {
+          Matched = true;
+          CombinesToApply.push_back(*Res);
+          break;
+        }
+      }
+      std::swap(LHS, RHS);
     }
-    std::swap(LHS, RHS);
+    // Right now we do an all or nothing approach.
+    if (!Matched)
+      return SDValue();
   }
-  return SDValue();
+  // Store the value for the replacement of the input node separately.
+  SDValue InputRootReplacement;
+  // We do the RAUW after we materialize all the combines, because some replaced
+  // nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
+  // some of these nodes may appear in the NodeExtensionHelpers of some of the
+  // yet-to-be-visited CombinesToApply roots.
+  SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
+  ValuesToReplace.reserve(CombinesToApply.size());
+  for (CombineResult Res : CombinesToApply) {
+    SDValue NewValue = Res.materialize(DAG);
+    if (!InputRootReplacement) {
+      assert(Res.Root == N &&
+             "First element is expected to be the current node");
+      InputRootReplacement = NewValue;
+    } else {
+      ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
+    }
+  }
+  for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
+    DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
+    DCI.AddToWorklist(OldNewValues.second.getNode());
+  }
+  return InputRootReplacement;
 }
 
 // Fold

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
new file mode 100644
index 0000000000000..4fdf7374c14cb
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+
+; Check that the add/sub/mul operations are all promoted into their
+; vw counterpart when the folding of the web size is increased to 3.
+; We need the web size to be at least 3 for the folding to happen, because
+; %c has 3 uses.
+define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
+; NO_FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; NO_FOLDING-NEXT:    vle8.v v8, (a0)
+; NO_FOLDING-NEXT:    vle8.v v9, (a1)
+; NO_FOLDING-NEXT:    vle8.v v10, (a2)
+; NO_FOLDING-NEXT:    vsext.vf2 v11, v8
+; NO_FOLDING-NEXT:    vsext.vf2 v8, v9
+; NO_FOLDING-NEXT:    vsext.vf2 v9, v10
+; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
+; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
+; FOLDING-NEXT:    vle8.v v8, (a0)
+; FOLDING-NEXT:    vle8.v v9, (a1)
+; FOLDING-NEXT:    vle8.v v10, (a2)
+; FOLDING-NEXT:    vwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vwadd.vv v9, v8, v10
+; FOLDING-NEXT:    vwsub.vv v12, v8, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; FOLDING-NEXT:    vor.vv v8, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    ret
+  %a = load <2 x i8>, <2 x i8>* %x
+  %b = load <2 x i8>, <2 x i8>* %y
+  %b2 = load <2 x i8>, <2 x i8>* %z
+  %c = sext <2 x i8> %a to <2 x i16>
+  %d = sext <2 x i8> %b to <2 x i16>
+  %d2 = sext <2 x i8> %b2 to <2 x i16>
+  %e = mul <2 x i16> %c, %d
+  %f = add <2 x i16> %c, %d2
+  %g = sub <2 x i16> %c, %d2
+  %h = or <2 x i16> %e, %f
+  %i = or <2 x i16> %h, %g
+  ret <2 x i16> %i
+}

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
index 8862e33647d4b..0026164103501 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
@@ -21,16 +21,14 @@ define <2 x i16> @vwmul_v2i16(<2 x i8>* %x, <2 x i8>* %y) {
 define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
 ; CHECK-LABEL: vwmul_v2i16_multiple_users:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
 ; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vsext.vf2 v11, v8
-; CHECK-NEXT:    vsext.vf2 v8, v9
-; CHECK-NEXT:    vsext.vf2 v9, v10
-; CHECK-NEXT:    vmul.vv v8, v11, v8
-; CHECK-NEXT:    vmul.vv v9, v11, v9
-; CHECK-NEXT:    vor.vv v8, v8, v9
+; CHECK-NEXT:    vwmul.vv v11, v8, v9
+; CHECK-NEXT:    vwmul.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; CHECK-NEXT:    vor.vv v8, v11, v9
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, <2 x i8>* %x
   %b = load <2 x i8>, <2 x i8>* %y


        


More information about the llvm-commits mailing list