[llvm] [RISCV] More explicitly check that combineOp_VLToVWOp_VL removes the extends it is supposed to. (PR #166710)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 5 21:56:50 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

If we visit multiple root nodes, make sure the strategy chosen
for each later root also removes the extensions we've already committed to remove.
    
This can occur if we have add/sub nodes where one operand is a zext
and the other is a sext. We might see that the nodes share a common
extension but then pick a strategy for one of them that doesn't fold it.
    
Stacked on #<!-- -->166700

---
Full diff: https://github.com/llvm/llvm-project/pull/166710.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+13-1) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll (+75) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c3f100e3197b1..3b5dfd5845f0b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17876,6 +17876,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
 
   SmallVector<SDNode *> Worklist;
   SmallPtrSet<SDNode *, 8> Inserted;
+  SmallPtrSet<SDNode *, 8> ExtensionsToRemove;
   Worklist.push_back(N);
   Inserted.insert(N);
   SmallVector<CombineResult> CombinesToApply;
@@ -17886,8 +17887,10 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
     NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
     NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
     auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
-                                &Inserted](const NodeExtensionHelper &Op) {
+                                &Inserted, &ExtensionsToRemove](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
+        // Remember that we're supposed to remove this extension.
+        ExtensionsToRemove.insert(Op.OrigOperand.getNode());
         for (SDUse &Use : Op.OrigOperand->uses()) {
           SDNode *TheUser = Use.getUser();
           if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget))
@@ -17921,6 +17924,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
         std::optional<CombineResult> Res =
             FoldingStrategy(Root, LHS, RHS, DAG, Subtarget);
         if (Res) {
+          // If this strategy wouldn't remove an extension we're supposed to
+          // remove, reject it.
+          if (!Res->LHSExt.has_value() &&
+              ExtensionsToRemove.contains(LHS.OrigOperand.getNode()))
+            continue;
+          if (!Res->RHSExt.has_value() &&
+              ExtensionsToRemove.contains(RHS.OrigOperand.getNode()))
+            continue;
+
           Matched = true;
           CombinesToApply.push_back(*Res);
           // All the inputs that are extended need to be folded, otherwise
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
index ad2ed47e67e64..034186210513c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -570,7 +570,82 @@ define <vscale x 2 x i32> @vwop_vscale_zext_i8i32_multiple_users(ptr %x, ptr %y,
   ret <vscale x 2 x i32> %i
 }
 
+define <vscale x 4 x i32> @mismatched_extend_sub_add(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_sub_add:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vzext.vf2 v10, v8
+; FOLDING-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT:    vwsub.wv v12, v10, v9
+; FOLDING-NEXT:    vwadd.wv v10, v10, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vmul.vv v8, v12, v10
+; FOLDING-NEXT:    ret
+  %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+  %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+  %c = sub <vscale x 4 x i32> %a, %b
+  %d = add <vscale x 4 x i32> %a, %b
+  %e = mul <vscale x 4 x i32> %c, %d
+  ret <vscale x 4 x i32> %e
+}
+
+; The sext of %y must still be folded away (no vsext) when the add's operands are commuted.
+define <vscale x 4 x i32> @mismatched_extend_sub_add_commuted(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_sub_add_commuted:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vzext.vf2 v10, v8
+; FOLDING-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT:    vwsub.wv v12, v10, v9
+; FOLDING-NEXT:    vwadd.wv v10, v10, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vmul.vv v8, v12, v10
+; FOLDING-NEXT:    ret
+  %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+  %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+  %c = sub <vscale x 4 x i32> %a, %b
+  %d = add <vscale x 4 x i32> %b, %a
+  %e = mul <vscale x 4 x i32> %c, %d
+  ret <vscale x 4 x i32> %e
+}
 
+define <vscale x 4 x i32> @mismatched_extend_add_sub(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_add_sub:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vzext.vf2 v10, v8
+; FOLDING-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT:    vwadd.wv v12, v10, v9
+; FOLDING-NEXT:    vwsub.wv v10, v10, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vmul.vv v8, v12, v10
+; FOLDING-NEXT:    ret
+  %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+  %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+  %c = add <vscale x 4 x i32> %a, %b
+  %d = sub <vscale x 4 x i32> %a, %b
+  %e = mul <vscale x 4 x i32> %c, %d
+  ret <vscale x 4 x i32> %e
+}
+
+define <vscale x 4 x i32> @mismatched_extend_add_sub_commuted(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_add_sub_commuted:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vzext.vf2 v10, v8
+; FOLDING-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT:    vwadd.wv v12, v10, v9
+; FOLDING-NEXT:    vwsub.wv v10, v10, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT:    vmul.vv v8, v12, v10
+; FOLDING-NEXT:    ret
+  %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+  %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+  %c = add <vscale x 4 x i32> %a, %b
+  %d = sub <vscale x 4 x i32> %a, %b
+  %e = mul <vscale x 4 x i32> %c, %d
+  ret <vscale x 4 x i32> %e
+}
 
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/166710


More information about the llvm-commits mailing list