[llvm] [RISCV] More explicitly check that combineOp_VLToVWOp_VL removes the extends it is supposed to. (PR #166710)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 5 21:56:50 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Craig Topper (topperc)
<details>
<summary>Changes</summary>
If we visit multiple root nodes, make sure the strategy chosen
for other nodes includes the nodes we've already committed to remove.
This can occur if we have add/sub nodes where one operand is a zext
and the other is a sext. We might see that the nodes share a common
extension but pick a strategy that doesn't share it.
Stacked on #<!-- -->166700
---
Full diff: https://github.com/llvm/llvm-project/pull/166710.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+13-1)
- (modified) llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll (+75)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c3f100e3197b1..3b5dfd5845f0b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17876,6 +17876,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
SmallVector<SDNode *> Worklist;
SmallPtrSet<SDNode *, 8> Inserted;
+ SmallPtrSet<SDNode *, 8> ExtensionsToRemove;
Worklist.push_back(N);
Inserted.insert(N);
SmallVector<CombineResult> CombinesToApply;
@@ -17886,8 +17887,10 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
- &Inserted](const NodeExtensionHelper &Op) {
+ &Inserted, &ExtensionsToRemove](const NodeExtensionHelper &Op) {
if (Op.needToPromoteOtherUsers()) {
+ // Remember that we're supposed to remove this extension.
+ ExtensionsToRemove.insert(Op.OrigOperand.getNode());
for (SDUse &Use : Op.OrigOperand->uses()) {
SDNode *TheUser = Use.getUser();
if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget))
@@ -17921,6 +17924,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
std::optional<CombineResult> Res =
FoldingStrategy(Root, LHS, RHS, DAG, Subtarget);
if (Res) {
+ // If this strategy wouldn't remove an extension we're supposed to
+ // remove, reject it.
+ if (!Res->LHSExt.has_value() &&
+ ExtensionsToRemove.contains(LHS.OrigOperand.getNode()))
+ continue;
+ if (!Res->RHSExt.has_value() &&
+ ExtensionsToRemove.contains(RHS.OrigOperand.getNode()))
+ continue;
+
Matched = true;
CombinesToApply.push_back(*Res);
// All the inputs that are extended need to be folded, otherwise
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
index ad2ed47e67e64..034186210513c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -570,7 +570,82 @@ define <vscale x 2 x i32> @vwop_vscale_zext_i8i32_multiple_users(ptr %x, ptr %y,
ret <vscale x 2 x i32> %i
}
+define <vscale x 4 x i32> @mismatched_extend_sub_add(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_sub_add:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vzext.vf2 v10, v8
+; FOLDING-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT: vwsub.wv v12, v10, v9
+; FOLDING-NEXT: vwadd.wv v10, v10, v9
+; FOLDING-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vmul.vv v8, v12, v10
+; FOLDING-NEXT: ret
+ %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+ %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+ %c = sub <vscale x 4 x i32> %a, %b
+ %d = add <vscale x 4 x i32> %a, %b
+ %e = mul <vscale x 4 x i32> %c, %d
+ ret <vscale x 4 x i32> %e
+}
+
+; FIXME: this should remove the vsext
+define <vscale x 4 x i32> @mismatched_extend_sub_add_commuted(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_sub_add_commuted:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vzext.vf2 v10, v8
+; FOLDING-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT: vwsub.wv v12, v10, v9
+; FOLDING-NEXT: vwadd.wv v10, v10, v9
+; FOLDING-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vmul.vv v8, v12, v10
+; FOLDING-NEXT: ret
+ %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+ %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+ %c = sub <vscale x 4 x i32> %a, %b
+ %d = add <vscale x 4 x i32> %b, %a
+ %e = mul <vscale x 4 x i32> %c, %d
+ ret <vscale x 4 x i32> %e
+}
+define <vscale x 4 x i32> @mismatched_extend_add_sub(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_add_sub:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vzext.vf2 v10, v8
+; FOLDING-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT: vwadd.wv v12, v10, v9
+; FOLDING-NEXT: vwsub.wv v10, v10, v9
+; FOLDING-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vmul.vv v8, v12, v10
+; FOLDING-NEXT: ret
+ %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+ %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+ %c = add <vscale x 4 x i32> %a, %b
+ %d = sub <vscale x 4 x i32> %a, %b
+ %e = mul <vscale x 4 x i32> %c, %d
+ ret <vscale x 4 x i32> %e
+}
+
+define <vscale x 4 x i32> @mismatched_extend_add_sub_commuted(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y) {
+; FOLDING-LABEL: mismatched_extend_add_sub_commuted:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vzext.vf2 v10, v8
+; FOLDING-NEXT: vsetvli zero, zero, e16, m1, ta, ma
+; FOLDING-NEXT: vwadd.wv v12, v10, v9
+; FOLDING-NEXT: vwsub.wv v10, v10, v9
+; FOLDING-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; FOLDING-NEXT: vmul.vv v8, v12, v10
+; FOLDING-NEXT: ret
+ %a = zext <vscale x 4 x i16> %x to <vscale x 4 x i32>
+ %b = sext <vscale x 4 x i16> %y to <vscale x 4 x i32>
+ %c = add <vscale x 4 x i32> %a, %b
+ %d = sub <vscale x 4 x i32> %a, %b
+ %e = mul <vscale x 4 x i32> %c, %d
+ ret <vscale x 4 x i32> %e
+}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; RV32: {{.*}}
``````````
</details>
https://github.com/llvm/llvm-project/pull/166710
More information about the llvm-commits
mailing list