[llvm] [RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl with zvfbfwma (PR #108798)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 16 00:23:15 PDT 2024
https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/108798
We currently make sure to check that if widening a f16 vector that we have zvfh. We need to do the same for bf16 vectors, but with the further restriction that we can only combine vfmadd_vl to vfwmadd_vl (to get vfwmaccbf16.v{v,f}).
This moves the checks into the extension support checks to keep it in one place.
>From d3b112857eb04ef97cff41edf8f2f5cf670df5ce Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 16 Sep 2024 15:13:00 +0800
Subject: [PATCH] [RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl
with zvfbfwma
We currently make sure to check that if widening a f16 vector that we have zvfh. We need to do the same for bf16 vectors, but with the further restriction that we can only combine vfmadd_vl to vfwmadd_vl (to get vfwmaccbf16.v{v,f}).
This moves the checks into the extension support checks to keep it in one place.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 34 +++++++++++++++------
llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll | 20 ++++++++++--
2 files changed, 42 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7d2a7b20ba2508..eb8ea95e2d8583 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14740,6 +14740,19 @@ struct NodeExtensionHelper {
EnforceOneUse = false;
}
+ bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
+ const RISCVSubtarget &Subtarget) {
+ // Any f16 extension will need zvfh
+ if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
+ return false;
+ // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
+ // zvfbfwma
+ if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
+ Root->getOpcode() != RISCVISD::VFMADD_VL))
+ return false;
+ return true;
+ }
+
/// Helper method to set the various fields of this struct based on the
/// type of \p Root.
void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
@@ -14775,9 +14788,14 @@ struct NodeExtensionHelper {
case RISCVISD::VSEXT_VL:
SupportsSExt = true;
break;
- case RISCVISD::FP_EXTEND_VL:
+ case RISCVISD::FP_EXTEND_VL: {
+ MVT NarrowEltVT =
+ OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
+ if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
+ break;
SupportsFPExt = true;
break;
+ }
case ISD::SPLAT_VECTOR:
case RISCVISD::VMV_V_X_VL:
fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -14792,6 +14810,10 @@ struct NodeExtensionHelper {
if (Op.getOpcode() != ISD::FP_EXTEND)
break;
+ if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
+ Subtarget))
+ break;
+
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
if (NarrowSize != ScalarBits)
@@ -15774,10 +15796,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N,
if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
return V;
- if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
- !Subtarget.hasVInstructionsF16() && !Subtarget.hasStdExtZvfbfwma())
- return SDValue();
-
// FIXME: Ignore strict opcodes for now.
if (N->isTargetStrictFPOpcode())
return SDValue();
@@ -17522,12 +17540,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::FSUB_VL:
case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
- case RISCVISD::VFWSUB_W_VL: {
- if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
- !Subtarget.hasVInstructionsF16())
- return SDValue();
+ case RISCVISD::VFWSUB_W_VL:
return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
- }
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
index 1ef0ed858d80a9..d8a9ab9c3937ec 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
@@ -1,6 +1,22 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh | FileCheck %s --check-prefixes=ZVFH
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin | FileCheck %s --check-prefixes=ZVFHMIN
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFH
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFHMIN
+
+define <vscale x 2 x float> @vfwadd_same_operand_nxv2bf16(<vscale x 2 x bfloat> %arg, i32 signext %vl) {
+; CHECK-LABEL: vfwadd_same_operand_nxv2bf16:
+; CHECK: # %bb.0: # %bb
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vfadd.vv v8, v9, v9
+; CHECK-NEXT: ret
+bb:
+ %tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2bf16(<vscale x 2 x bfloat> %arg, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ %tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+ ret <vscale x 2 x float> %tmp2
+}
define <vscale x 2 x float> @vfwadd_same_operand(<vscale x 2 x half> %arg, i32 signext %vl) {
; ZVFH-LABEL: vfwadd_same_operand:
More information about the llvm-commits
mailing list