[llvm] [RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl with zvfbfwma (PR #108798)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 16 00:23:15 PDT 2024


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/108798

We currently check that we have zvfh before widening an f16 vector. We need to do the same for bf16 vectors, with the further restriction that the only combine we can do is vfmadd_vl to vfwmadd_vl (to get vfwmaccbf16.v{v,f}).

This moves the checks into the extension support checks to keep them in one place.
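
As an illustrative sketch (not taken from this patch; the function name and types are made up for the example), the one bf16 shape the combine should still accept is an fma whose operands are bf16 values extended to f32, which zvfbfwma should be able to fold into vfwmaccbf16.vv:

  declare <vscale x 2 x float> @llvm.fma.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>)

  define <vscale x 2 x float> @bf16_fma_widen(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, <vscale x 2 x float> %acc) {
    ; bf16 operands are extended to f32 and fed into an fma; with +zvfbfwma
    ; this is the pattern that can still be widened.
    %a.ext = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x float>
    %b.ext = fpext <vscale x 2 x bfloat> %b to <vscale x 2 x float>
    %res = call <vscale x 2 x float> @llvm.fma.nxv2f32(<vscale x 2 x float> %a.ext, <vscale x 2 x float> %b.ext, <vscale x 2 x float> %acc)
    ret <vscale x 2 x float> %res
  }

Other bf16 widening candidates, like the fadd in the new test below, are now left as a separate vfwcvtbf16.f.f.v plus a full-width vfadd.vv.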


>From d3b112857eb04ef97cff41edf8f2f5cf670df5ce Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 16 Sep 2024 15:13:00 +0800
Subject: [PATCH] [RISCV] Restrict combineOp_VLToVWOp_VL w/ bf16 to vfwmadd_vl
 with zvfbfwma

We currently check that we have zvfh before widening an f16 vector. We need to do the same for bf16 vectors, with the further restriction that the only combine we can do is vfmadd_vl to vfwmadd_vl (to get vfwmaccbf16.v{v,f}).

This moves the checks into the extension support checks to keep them in one place.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 34 +++++++++++++++------
 llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll    | 20 ++++++++++--
 2 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7d2a7b20ba2508..eb8ea95e2d8583 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14740,6 +14740,19 @@ struct NodeExtensionHelper {
     EnforceOneUse = false;
   }
 
+  bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
+                           const RISCVSubtarget &Subtarget) {
+    // Any f16 extension will need zvfh
+    if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
+      return false;
+    // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
+    // zvfbfwma
+    if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
+                                     Root->getOpcode() != RISCVISD::VFMADD_VL))
+      return false;
+    return true;
+  }
+
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
   void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
@@ -14775,9 +14788,14 @@ struct NodeExtensionHelper {
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
       break;
-    case RISCVISD::FP_EXTEND_VL:
+    case RISCVISD::FP_EXTEND_VL: {
+      MVT NarrowEltVT =
+          OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
+      if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
+        break;
       SupportsFPExt = true;
       break;
+    }
     case ISD::SPLAT_VECTOR:
     case RISCVISD::VMV_V_X_VL:
       fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -14792,6 +14810,10 @@ struct NodeExtensionHelper {
       if (Op.getOpcode() != ISD::FP_EXTEND)
         break;
 
+      if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
+                               Subtarget))
+        break;
+
       unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
       unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
       if (NarrowSize != ScalarBits)
@@ -15774,10 +15796,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N,
   if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
     return V;
 
-  if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
-      !Subtarget.hasVInstructionsF16() && !Subtarget.hasStdExtZvfbfwma())
-    return SDValue();
-
   // FIXME: Ignore strict opcodes for now.
   if (N->isTargetStrictFPOpcode())
     return SDValue();
@@ -17522,12 +17540,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::FSUB_VL:
   case RISCVISD::FMUL_VL:
   case RISCVISD::VFWADD_W_VL:
-  case RISCVISD::VFWSUB_W_VL: {
-    if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
-        !Subtarget.hasVInstructionsF16())
-      return SDValue();
+  case RISCVISD::VFWSUB_W_VL:
     return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
-  }
   case ISD::LOAD:
   case ISD::STORE: {
     if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
index 1ef0ed858d80a9..d8a9ab9c3937ec 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
@@ -1,6 +1,22 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh | FileCheck %s --check-prefixes=ZVFH
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin | FileCheck %s --check-prefixes=ZVFHMIN
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFH
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfhmin,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFHMIN
+
+define <vscale x 2 x float> @vfwadd_same_operand_nxv2bf16(<vscale x 2 x bfloat> %arg, i32 signext %vl) {
+; CHECK-LABEL: vfwadd_same_operand_nxv2bf16:
+; CHECK:       # %bb.0: # %bb
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vfwcvtbf16.f.f.v v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vfadd.vv v8, v9, v9
+; CHECK-NEXT:    ret
+bb:
+  %tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2bf16(<vscale x 2 x bfloat> %arg, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+  %tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> splat (i1 true), i32 %vl)
+  ret <vscale x 2 x float> %tmp2
+}
 
 define <vscale x 2 x float> @vfwadd_same_operand(<vscale x 2 x half> %arg, i32 signext %vl) {
 ; ZVFH-LABEL: vfwadd_same_operand:


