[llvm-branch-commits] [llvm] release/21.x: [RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are handled in combineOp_VLToVWOp_VL. (#159205) (PR #159891)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Sep 19 18:50:35 PDT 2025
https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/159891
Backport 6119d1f115625cd1b8a2b9d331609eb9e9f676ce
Requested by: @topperc
>From 6b036105fdfcde8b38e99fa68ccd820c5652fc3a Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 19 Sep 2025 09:19:57 -0700
Subject: [PATCH] [RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are
handled in combineOp_VLToVWOp_VL. (#159205)
These instructions have one already narrow operand. Previously, we
pretended like this operand was a supported extension.
This could cause problems when we called getOrCreateExtendedOp on this
narrow operand when creating the VWADD_VL. If the narrow operand
happened to be an extend of the opposite type, we would peek through it
and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32
(sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16
(sext Y))).
To prevent this, we ignore the operand instead and pass std::nullopt for
SupportsExt to getOrCreateExtendedOp so it won't peek through any
extends on the narrow source.
Fixes #159152.
(cherry picked from commit 6119d1f115625cd1b8a2b9d331609eb9e9f676ce)
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 86 +++++++++++--------
.../fixed-vectors-vw-web-simplification.ll | 23 +++++
2 files changed, 72 insertions(+), 37 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5fb16f5ac6b9e..347f6c99852e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16936,18 +16936,9 @@ struct NodeExtensionHelper {
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFWSUB_W_VL:
- if (OperandIdx == 1) {
- SupportsZExt =
- Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
- SupportsSExt =
- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
- SupportsFPExt =
- Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
- // There's no existing extension here, so we don't have to worry about
- // making sure it gets removed.
- EnforceOneUse = false;
+ // Operand 1 can't be changed.
+ if (OperandIdx == 1)
break;
- }
[[fallthrough]];
default:
fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -16985,20 +16976,20 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::OR_VL:
- case RISCVISD::VWADD_W_VL:
- case RISCVISD::VWADDU_W_VL:
case RISCVISD::FADD_VL:
case RISCVISD::FMUL_VL:
- case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMSUB_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
return true;
+ case RISCVISD::VWADD_W_VL:
+ case RISCVISD::VWADDU_W_VL:
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::VFWADD_W_VL:
case RISCVISD::FSUB_VL:
case RISCVISD::VFWSUB_W_VL:
case ISD::SHL:
@@ -17117,6 +17108,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
Subtarget);
}
+/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
+ Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
+ Subtarget);
+}
+
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17145,7 +17160,7 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
return std::nullopt;
}
-/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
+/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
@@ -17153,11 +17168,14 @@ static std::optional<CombineResult>
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
- Subtarget);
+ if (LHS.SupportsSExt)
+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+ /*RHSExt=*/std::nullopt);
+ return std::nullopt;
}
-/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
+/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
@@ -17165,11 +17183,14 @@ static std::optional<CombineResult>
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
- Subtarget);
+ if (LHS.SupportsZExt)
+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+ /*RHSExt=*/std::nullopt);
+ return std::nullopt;
}
-/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
@@ -17177,20 +17198,11 @@ static std::optional<CombineResult>
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
- Subtarget);
-}
-
-/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
-///
-/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
-/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
- Subtarget);
+ if (LHS.SupportsFPExt)
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+ /*RHSExt=*/std::nullopt);
+ return std::nullopt;
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17233,7 +17245,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
- Strategies.push_back(canFoldToVWWithBF16EXT);
+ Strategies.push_back(canFoldToVWWithSameExtBF16);
break;
case ISD::MUL:
case RISCVISD::MUL_VL:
@@ -17245,7 +17257,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::SHL:
case RISCVISD::SHL_VL:
// shl -> vwsll
- Strategies.push_back(canFoldToVWWithZEXT);
+ Strategies.push_back(canFoldToVWWithSameExtZEXT);
break;
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWSUB_W_VL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
index 227a428831b60..ea4add2da5ebc 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
@@ -58,3 +58,26 @@ define <2 x i16> @vwmul_v2i16_multiple_users(ptr %x, ptr %y, ptr %z) {
%i = or <2 x i16> %h, %g
ret <2 x i16> %i
}
+
+; Make sure we have a vsext.vl and a vwaddu.vx.
+define <4 x i32> @pr159152(<4 x i8> %x) {
+; NO_FOLDING-LABEL: pr159152:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; NO_FOLDING-NEXT: vsext.vf2 v9, v8
+; NO_FOLDING-NEXT: li a0, 9
+; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: pr159152:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; FOLDING-NEXT: vsext.vf2 v9, v8
+; FOLDING-NEXT: li a0, 9
+; FOLDING-NEXT: vwaddu.vx v8, v9, a0
+; FOLDING-NEXT: ret
+ %a = sext <4 x i8> %x to <4 x i16>
+ %b = zext <4 x i16> %a to <4 x i32>
+ %c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9>
+ ret <4 x i32> %c
+}
More information about the llvm-branch-commits
mailing list