[llvm] 6af2f22 - [RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl with zvfbfwma (#108798)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 16 22:35:30 PDT 2024
Author: Luke Lau
Date: 2024-09-17T13:35:25+08:00
New Revision: 6af2f225a0f820d331f251af69c2dad0c845964e
URL: https://github.com/llvm/llvm-project/commit/6af2f225a0f820d331f251af69c2dad0c845964e
DIFF: https://github.com/llvm/llvm-project/commit/6af2f225a0f820d331f251af69c2dad0c845964e.diff
LOG: [RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl with zvfbfwma (#108798)
When folding an op to an f16 widening op, we currently check that we have
zvfh. We need to do the same for bf16 vectors, but with the further
restriction that we can only combine vfmadd_vl to vfwmadd_vl (to get
vfwmaccbf16.v{v,f}).
The added test case currently crashes because we try to fold an add to a
bf16 widening add, which doesn't exist in zvfbfmin or zvfbfwma.
This moves the checks into the extension support checks to keep them in one
place.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7d2a7b20ba2508..eb8ea95e2d8583 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14740,6 +14740,19 @@ struct NodeExtensionHelper {
EnforceOneUse = false;
}
+ bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
+ const RISCVSubtarget &Subtarget) {
+ // Any f16 extension will need zvfh
+ if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
+ return false;
+ // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
+ // zvfbfwma
+ if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
+ Root->getOpcode() != RISCVISD::VFMADD_VL))
+ return false;
+ return true;
+ }
+
/// Helper method to set the various fields of this struct based on the
/// type of \p Root.
void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
@@ -14775,9 +14788,14 @@ struct NodeExtensionHelper {
case RISCVISD::VSEXT_VL:
SupportsSExt = true;
break;
- case RISCVISD::FP_EXTEND_VL:
+ case RISCVISD::FP_EXTEND_VL: {
+ MVT NarrowEltVT =
+ OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
+ if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
+ break;
SupportsFPExt = true;
break;
+ }
case ISD::SPLAT_VECTOR:
case RISCVISD::VMV_V_X_VL:
fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -14792,6 +14810,10 @@ struct NodeExtensionHelper {
if (Op.getOpcode() != ISD::FP_EXTEND)
break;
+ if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
+ Subtarget))
+ break;
+
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
if (NarrowSize != ScalarBits)
@@ -15774,10 +15796,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N,
if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
return V;
- if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
- !Subtarget.hasVInstructionsF16() && !Subtarget.hasStdExtZvfbfwma())
- return SDValue();
-
// FIXME: Ignore strict opcodes for now.
if (N->isTargetStrictFPOpcode())
return SDValue();
@@ -17522,12 +17540,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::FSUB_VL:
case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
- case RISCVISD::VFWSUB_W_VL: {
- if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
- !Subtarget.hasVInstructionsF16())
- return SDValue();
+ case RISCVISD::VFWSUB_W_VL:
return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
- }
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
index 1ef0ed858d80a9..f7297927db7174 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
@@ -1,6 +1,44 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh | FileCheck %s --check-prefixes=ZVFH
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin | FileCheck %s --check-prefixes=ZVFHMIN
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFH
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFHMIN
+
+define <vscale x 2 x float> @vfwadd_same_operand_nxv2bf16(<vscale x 2 x bfloat> %arg, i32 signext %vl) {
+; CHECK-LABEL: vfwadd_same_operand_nxv2bf16:
+; CHECK: # %bb.0: # %bb
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vfadd.vv v8, v9, v9
+; CHECK-NEXT: ret
+bb:
+ %tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2bf16(<vscale x 2 x bfloat> %arg, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ %tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ ret <vscale x 2 x float> %tmp2
+}
+
+; Make sure we don't widen vfmadd.vv -> vfwmaccbf16.vv if there are other
+; unwidenable uses
+define <vscale x 2 x float> @vfwadd_same_operand_nxv2bf16_multiuse(<vscale x 2 x bfloat> %arg, <vscale x 2 x float> %acc, i32 signext %vl, ptr %p) {
+; CHECK-LABEL: vfwadd_same_operand_nxv2bf16_multiuse:
+; CHECK: # %bb.0: # %bb
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vfadd.vv v8, v10, v10
+; CHECK-NEXT: vfmadd.vv v10, v10, v9
+; CHECK-NEXT: vs1r.v v10, (a1)
+; CHECK-NEXT: ret
+bb:
+ %tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2bf16(<vscale x 2 x bfloat> %arg, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ %tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ %tmp3 = call <vscale x 2 x float> @llvm.vp.fma.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x float> %acc, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ store <vscale x 2 x float> %tmp3, ptr %p
+ ret <vscale x 2 x float> %tmp2
+}
define <vscale x 2 x float> @vfwadd_same_operand(<vscale x 2 x half> %arg, i32 signext %vl) {
; ZVFH-LABEL: vfwadd_same_operand:
More information about the llvm-commits
mailing list