[llvm] f20619c - [RISCV] More explicitly check that combineOp_VLToVWOp_VL removes the extends it is supposed to. (#166710)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 6 10:16:51 PST 2025


Author: Craig Topper
Date: 2025-11-06T10:16:47-08:00
New Revision: f20619c610f58e04633b1b053f323fe46a30c8d1

URL: https://github.com/llvm/llvm-project/commit/f20619c610f58e04633b1b053f323fe46a30c8d1
DIFF: https://github.com/llvm/llvm-project/commit/f20619c610f58e04633b1b053f323fe46a30c8d1.diff

LOG: [RISCV] More explicitly check that combineOp_VLToVWOp_VL removes the extends it is supposed to. (#166710)

If we visit multiple root nodes, make sure the strategy chosen
for other nodes includes the nodes we've already committed to remove.
    
This can occur if we have add/sub nodes where one operand is a zext
and the other is a sext. We might see that the nodes share a common
extension but pick a strategy that doesn't share it.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 995ae75da1c30..3b69edacb8982 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17867,6 +17867,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
 
   SmallVector<SDNode *> Worklist;
   SmallPtrSet<SDNode *, 8> Inserted;
+  SmallPtrSet<SDNode *, 8> ExtensionsToRemove;
   Worklist.push_back(N);
   Inserted.insert(N);
   SmallVector<CombineResult> CombinesToApply;
@@ -17876,22 +17877,25 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
 
     NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
     NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
-    auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
-                                &Inserted](const NodeExtensionHelper &Op) {
-      if (Op.needToPromoteOtherUsers()) {
-        for (SDUse &Use : Op.OrigOperand->uses()) {
-          SDNode *TheUser = Use.getUser();
-          if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget))
-            return false;
-          // We only support the first 2 operands of FMA.
-          if (Use.getOperandNo() >= 2)
-            return false;
-          if (Inserted.insert(TheUser).second)
-            Worklist.push_back(TheUser);
-        }
-      }
-      return true;
-    };
+    auto AppendUsersIfNeeded =
+        [&Worklist, &Subtarget, &Inserted,
+         &ExtensionsToRemove](const NodeExtensionHelper &Op) {
+          if (Op.needToPromoteOtherUsers()) {
+            // Remember that we're supposed to remove this extension.
+            ExtensionsToRemove.insert(Op.OrigOperand.getNode());
+            for (SDUse &Use : Op.OrigOperand->uses()) {
+              SDNode *TheUser = Use.getUser();
+              if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget))
+                return false;
+              // We only support the first 2 operands of FMA.
+              if (Use.getOperandNo() >= 2)
+                return false;
+              if (Inserted.insert(TheUser).second)
+                Worklist.push_back(TheUser);
+            }
+          }
+          return true;
+        };
 
     // Control the compile time by limiting the number of node we look at in
     // total.
@@ -17912,6 +17916,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
         std::optional<CombineResult> Res =
             FoldingStrategy(Root, LHS, RHS, DAG, Subtarget);
         if (Res) {
+          // If this strategy wouldn't remove an extension we're supposed to
+          // remove, reject it.
+          if (!Res->LHSExt.has_value() &&
+              ExtensionsToRemove.contains(LHS.OrigOperand.getNode()))
+            continue;
+          if (!Res->RHSExt.has_value() &&
+              ExtensionsToRemove.contains(RHS.OrigOperand.getNode()))
+            continue;
+
           Matched = true;
           CombinesToApply.push_back(*Res);
           // All the inputs that are extended need to be folded, otherwise

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
index b1f0eee3e9f52..034186210513c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -595,12 +595,11 @@ define <vscale x 4 x i32> @mismatched_extend_sub_add_commuted(<vscale x 4 x i16>
 ; FOLDING:       # %bb.0:
 ; FOLDING-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
 ; FOLDING-NEXT:    vzext.vf2 v10, v8
-; FOLDING-NEXT:    vsext.vf2 v12, v9
 ; FOLDING-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
-; FOLDING-NEXT:    vwsub.wv v10, v10, v9
-; FOLDING-NEXT:    vwaddu.wv v12, v12, v8
+; FOLDING-NEXT:    vwsub.wv v12, v10, v9
+; FOLDING-NEXT:    vwadd.wv v10, v10, v9
 ; FOLDING-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
-; FOLDING-NEXT:    vmul.vv v8, v10, v12
+; FOLDING-NEXT:    vmul.vv v8, v12, v10
 ; FOLDING-NEXT:    ret
   %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
   %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>


        


More information about the llvm-commits mailing list