[llvm] [RISCV][WIP] Treat bf16->f32 as separate ExtKind in combineOp_VLToVWOp_VL. (PR #144653)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 18 12:49:44 PDT 2025


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/144653

From c2f0ce3b14e466eeb22ce15d32ee85d1f573fb77 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 18 Jun 2025 01:50:43 -0700
Subject: [PATCH 1/2] [RISCV][WIP] Treat bf16->f32 as separate ExtKind in
 combineOp_VLToVWOp_VL.

This lets us track the narrow type we need more precisely and fixes
miscompiles when f16->f32 and bf16->f32 extends are mixed.

Fixes #144651.

Still need to add tests, but it's late and I need sleep.
---
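
For reference, a minimal IR sketch of the mixed-extend miscompile (the
same pattern as the `mix` test added in the second patch): previously
both fpexts were reported as ExtKind::FPExt, so the combine could not
tell the f16 source apart from the bf16 source.

  define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
    %a_ext = fpext <4 x half> %a to <4 x float>
    %b_ext = fpext <4 x bfloat> %b to <4 x float>
    %fma = call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
    ret <4 x float> %fma
  }

With the separate ExtKind::BF16Ext, canFoldToVWWithSameExtensionImpl
only forms the bf16 widening multiply-add when both operands are bf16
extends and the root is VFMADD_VL, so the two converts above are kept
separate.
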
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 73 ++++++++++++---------
 1 file changed, 42 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e670567bd1844..f7d447e03af94 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16309,7 +16309,12 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
-enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
+enum ExtKind : uint8_t {
+  ZExt = 1 << 0,
+  SExt = 1 << 1,
+  FPExt = 1 << 2,
+  BF16Ext = 1 << 3
+};
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16344,8 +16349,10 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
-  /// Records if this operand is like being floating-Point extended.
+  /// Records if this operand is like being floating-point extended.
   bool SupportsFPExt;
+  /// Records if this operand is extended from bf16.
+  bool SupportsBF16Ext;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -16381,6 +16388,7 @@ struct NodeExtensionHelper {
     case ExtKind::ZExt:
       return RISCVISD::VZEXT_VL;
     case ExtKind::FPExt:
+    case ExtKind::BF16Ext:
       return RISCVISD::FP_EXTEND_VL;
     }
     llvm_unreachable("Unknown ExtKind enum");
@@ -16402,13 +16410,6 @@ struct NodeExtensionHelper {
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    // vfmadd_vl -> vfwmadd_vl can take bf16 operands
-    if (Source.getValueType().getVectorElementType() == MVT::bf16) {
-      assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
-             Root->getOpcode() == RISCVISD::VFMADD_VL);
-      return Source;
-    }
-
     unsigned ExtOpc = getExtOpc(*SupportsExt);
 
     // If we need an extension, we should be changing the type.
@@ -16451,7 +16452,8 @@ struct NodeExtensionHelper {
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
 
-    MVT EltVT = SupportsExt == ExtKind::FPExt
+    MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
+                : SupportsExt == ExtKind::FPExt
                     ? MVT::getFloatingPointVT(NarrowSize)
                     : MVT::getIntegerVT(NarrowSize);
 
@@ -16628,17 +16630,17 @@ struct NodeExtensionHelper {
     EnforceOneUse = false;
   }
 
-  bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
-                           const RISCVSubtarget &Subtarget) {
+  bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+    if (NarrowEltVT == MVT::f32)
+      return true;
     // Any f16 extension will need zvfh
-    if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
-      return false;
-    // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
-    // zvfbfwma
-    if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
-                                     Root->getOpcode() != RISCVISD::VFMADD_VL))
-      return false;
-    return true;
+    if (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16())
+      return true;
+    return false;
+  }
+
+  bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+    return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
   }
 
   /// Helper method to set the various fields of this struct based on the
@@ -16648,6 +16650,7 @@ struct NodeExtensionHelper {
     SupportsZExt = false;
     SupportsSExt = false;
     SupportsFPExt = false;
+    SupportsBF16Ext = false;
     EnforceOneUse = true;
     unsigned Opc = OrigOperand.getOpcode();
     // For the nodes we handle below, we end up using their inputs directly: see
@@ -16679,9 +16682,11 @@ struct NodeExtensionHelper {
     case RISCVISD::FP_EXTEND_VL: {
       MVT NarrowEltVT =
           OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
-      if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
-        break;
-      SupportsFPExt = true;
+      if (isSupportedFPExtend(NarrowEltVT, Subtarget))
+        SupportsFPExt = true;
+      if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
+        SupportsBF16Ext = true;
+
       break;
     }
     case ISD::SPLAT_VECTOR:
@@ -16698,16 +16703,16 @@ struct NodeExtensionHelper {
       if (Op.getOpcode() != ISD::FP_EXTEND)
         break;
 
-      if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
-                               Subtarget))
-        break;
-
       unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
       unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
       if (NarrowSize != ScalarBits)
         break;
 
-      SupportsFPExt = true;
+      if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
+        SupportsFPExt = true;
+      if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
+                                Subtarget))
+        SupportsBF16Ext = true;
       break;
     }
     default:
@@ -16940,6 +16945,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
     return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
                          /*RHSExt=*/{ExtKind::FPExt});
+  if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
+      RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL)
+    return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
+                         /*RHSExt=*/{ExtKind::BF16Ext});
   return std::nullopt;
 }
 
@@ -16953,9 +16963,10 @@ static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(
-      Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
-      Subtarget);
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS,
+                                          ExtKind::ZExt | ExtKind::SExt |
+                                              ExtKind::FPExt | ExtKind::BF16Ext,
+                                          DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))

From aafd72ef88b0386056a3e7b00819460c93220b95 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 18 Jun 2025 12:15:34 -0700
Subject: [PATCH 2/2] fixup! add tests

---
 .../RISCV/rvv/fixed-vectors-vfwmaccbf16.ll    | 58 +++++++++++++++++--
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
index 1639f21f243d8..aec970adff51e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
-; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
 
 define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) {
 ; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32:
@@ -295,3 +295,53 @@ define <32 x float> @vfwmaccbf32_vf_v32f32(<32 x float> %a, bfloat %b, <32 x bfl
   %res = call <32 x float> @llvm.fma.v32f32(<32 x float> %b.ext, <32 x float> %c.ext, <32 x float> %a)
   ret <32 x float> %res
 }
+
+define <4 x float> @vfwmaccbf16_vf_v4f32_scalar_extend(<4 x float> %rd, bfloat %a, <4 x bfloat> %b) {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
+; ZVFBFWMA:       # %bb.0:
+; ZVFBFWMA-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT:    vfwmaccbf16.vf v8, fa0, v9
+; ZVFBFWMA-NEXT:    ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
+; ZVFBFMIN:       # %bb.0:
+; ZVFBFMIN-NEXT:    fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT:    vfwcvtbf16.f.f.v v10, v9
+; ZVFBFMIN-NEXT:    slli a0, a0, 16
+; ZVFBFMIN-NEXT:    fmv.w.x fa5, a0
+; ZVFBFMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT:    vfmacc.vf v8, fa5, v10
+; ZVFBFMIN-NEXT:    ret
+  %b_ext = fpext <4 x bfloat> %b to <4 x float>
+  %a_extend = fpext bfloat %a to float
+  %a_insert = insertelement <4 x float> poison, float %a_extend, i64 0
+  %a_shuffle = shufflevector <4 x float> %a_insert, <4 x float> poison, <4 x i32> zeroinitializer
+  %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_shuffle, <4 x float> %b_ext, <4 x float> %rd)
+  ret <4 x float> %fma
+}
+
+; Negative test with a mix of half and bfloat fpext; this must not form vfwmaccbf16.
+define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
+; ZVFBFWMA-LABEL: mix:
+; ZVFBFWMA:       # %bb.0:
+; ZVFBFWMA-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT:    vfwcvt.f.f.v v11, v9
+; ZVFBFWMA-NEXT:    vfwcvtbf16.f.f.v v9, v10
+; ZVFBFWMA-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFWMA-NEXT:    vfmacc.vv v8, v11, v9
+; ZVFBFWMA-NEXT:    ret
+;
+; ZVFBFMIN-LABEL: mix:
+; ZVFBFMIN:       # %bb.0:
+; ZVFBFMIN-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT:    vfwcvt.f.f.v v11, v9
+; ZVFBFMIN-NEXT:    vfwcvtbf16.f.f.v v9, v10
+; ZVFBFMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT:    vfmacc.vv v8, v11, v9
+; ZVFBFMIN-NEXT:    ret
+  %a_ext = fpext <4 x half> %a to <4 x float>
+  %b_ext = fpext <4 x bfloat> %b to <4 x float>
+  %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
+  ret <4 x float> %fma
+}


