[llvm] [RISCV] Support Inline ASM for the bf16 type. (PR #80118)

Chuan-Yue Yuan via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 23:02:14 PDT 2024


https://github.com/circYuan updated https://github.com/llvm/llvm-project/pull/80118

>From daa224ad73986ca8788fbe4c988132ce07e9be24 Mon Sep 17 00:00:00 2001
From: Tony Chuan-Yue Yuan <yuan593 at andestech.com>
Date: Tue, 28 May 2024 13:56:15 +0800
Subject: [PATCH 1/3] [RISCV] Support Inline ASM for the bf16 type.

This patch makes the RISC-V inline asm constraint `f` recognize the
bfloat (bf16) type.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 16 +++-
 .../RISCV/inline-asm-f-constraint-bf16.ll     | 77 +++++++++++++++++++
 2 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..f5708f62661e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20155,7 +20155,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         return std::make_pair(0U, &RISCV::GPRPairRegClass);
       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
     case 'f':
-      if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
+      if ((Subtarget.hasStdExtZfhmin() && VT == MVT::f16) ||
+          (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16))
         return std::make_pair(0U, &RISCV::FPR16RegClass);
       if (Subtarget.hasStdExtF() && VT == MVT::f32)
         return std::make_pair(0U, &RISCV::FPR32RegClass);
@@ -20273,6 +20274,11 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         unsigned HReg = RISCV::F0_H + RegNo;
         return std::make_pair(HReg, &RISCV::FPR16RegClass);
       }
+      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16){
+        unsigned RegNo = FReg - RISCV::F0_F;
+        unsigned HReg = RISCV::F0_H + RegNo;
+        return std::make_pair(HReg, &RISCV::FPR16RegClass);
+      }
     }
   }
 
@@ -20949,6 +20955,14 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
     return true;
   }
 
+  // Since the inline asm only use the first type in the RegisterClass, the bf16
+  // inline asm would choose the f16 from FPR16RegClass for doing the copy, and
+  // we correct the behavior here to avoid generating wrong SelectionDAG.
+  if (ValueVT == MVT::bf16 && PartVT == MVT::f16){
+    Parts[0] = Val;
+    return true;
+  }
+
   if (ValueVT.isScalableVector() && PartVT.isScalableVector()) {
     LLVMContext &Context = *DAG.getContext();
     EVT ValueEltVT = ValueVT.getVectorElementType();
diff --git a/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
new file mode 100644
index 0000000000000..0fec5088192b5
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+f,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64F %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64F %s
+
+ at gf = external global float
+
+define float @constraint_f_float(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    fmv.h.x fa5, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.w a0, fa5
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_f_float:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    fmv.h.x fa5, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.w a0, fa5
+; RV64F-NEXT:    ret
+  %1 = load float, float* @gf
+  %2 = tail call float asm "fcvt.s.bf16 $0, $1", "=f,f"(bfloat %a)
+  ret float %2
+}
+
+define float @constraint_f_float_abi_name(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float_abi_name:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    fmv.h.x fa0, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.s.bf16 ft0, fa0
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.w a0, ft0
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_f_float_abi_name:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    fmv.h.x fa0, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.s.bf16 ft0, fa0
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.w a0, ft0
+; RV64F-NEXT:    ret
+  %1 = load float, float* @gf
+  %2 = tail call float asm "fcvt.s.bf16 $0, $1", "={ft0},{fa0}"(bfloat %a)
+  ret float %2
+}
+
+define bfloat @constraint_gpr(bfloat %x) {
+; RV32F-LABEL: constraint_gpr:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    .cfi_def_cfa_offset 0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    mv a0, a0
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_gpr:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    .cfi_def_cfa_offset 0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    mv a0, a0
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    ret
+  %1 = tail call bfloat asm sideeffect alignstack "mv $0, $1", "={x10},{x10}"(bfloat %x)
+  ret bfloat %1
+}

>From e85a835ed30dae02e4963960cc1822c93df6edb1 Mon Sep 17 00:00:00 2001
From: Tony Chuan-Yue Yuan <yuan593 at andestech.com>
Date: Tue, 28 May 2024 14:02:36 +0800
Subject: [PATCH 2/3] Fix formatting typos and merge the redundant if conditions.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f5708f62661e7..3f918615bf0cc 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20269,12 +20269,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       }
       if (VT == MVT::f32 || VT == MVT::Other)
         return std::make_pair(FReg, &RISCV::FPR32RegClass);
-      if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) {
-        unsigned RegNo = FReg - RISCV::F0_F;
-        unsigned HReg = RISCV::F0_H + RegNo;
-        return std::make_pair(HReg, &RISCV::FPR16RegClass);
-      }
-      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16){
+      if ((Subtarget.hasStdExtZfhmin() && VT == MVT::f16) ||
+          (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16)) {
         unsigned RegNo = FReg - RISCV::F0_F;
         unsigned HReg = RISCV::F0_H + RegNo;
         return std::make_pair(HReg, &RISCV::FPR16RegClass);
@@ -20958,7 +20954,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
   // Since the inline asm only use the first type in the RegisterClass, the bf16
   // inline asm would choose the f16 from FPR16RegClass for doing the copy, and
   // we correct the behavior here to avoid generating wrong SelectionDAG.
-  if (ValueVT == MVT::bf16 && PartVT == MVT::f16){
+  if (ValueVT == MVT::bf16 && PartVT == MVT::f16) {
     Parts[0] = Val;
     return true;
   }

>From 7c827d2c54d752c18549c17ba88bd703b294c0a0 Mon Sep 17 00:00:00 2001
From: Tony Chuan-Yue Yuan <yuan593 at andestech.com>
Date: Fri, 7 Jun 2024 13:55:40 +0800
Subject: [PATCH 3/3] Support and test inline asm return values of type
 bf16.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 28 ++++++++++-
 .../RISCV/inline-asm-f-constraint-bf16.ll     | 48 +++++++++++++------
 2 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3f918615bf0cc..bf8476e032eea 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -479,6 +479,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     // FIXME: Need to promote bf16 FCOPYSIGN to f32, but the
     // DAGCombiner::visitFP_ROUND probably needs improvements first.
     setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+    // To fold (bf16 bitcast (copyfromreg f16)) -> (copyfromreg bf16), we have
+    // to legalize the f16 CopyFromReg for avoiding SoftPromoteHalf.
+    setOperationAction(ISD::CopyFromReg, MVT::f16, Legal);
+    // Fold the (bf16 bitcast (copyfromreg f16)) -> (copyfromreg bf16).
+    setTargetDAGCombine(ISD::BITCAST);
   }
 
   if (Subtarget.hasStdExtZfhminOrZhinxmin()) {
@@ -17088,10 +17094,30 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
   }
   case ISD::BITCAST: {
-    assert(Subtarget.useRVVForFixedLengthVectors());
+    assert(Subtarget.useRVVForFixedLengthVectors() ||
+           Subtarget.hasStdExtZfbfmin());
     SDValue N0 = N->getOperand(0);
     EVT VT = N->getValueType(0);
     EVT SrcVT = N0.getValueType();
+
+    // Fold the (bf16 bitcast (copyfromreg f16)) -> (copyfromreg bf16).
+    if (SrcVT == MVT::f16 && VT == MVT::bf16) {
+      if (N0.getOpcode() == ISD::CopyFromReg) {
+        SDValue F16CopyFromReg = N0->getOperand(1);
+        Register BFReg = cast<RegisterSDNode>(F16CopyFromReg)->getReg();
+        SDValue Chain = N0->getOperand(0);
+        SDValue NewCopy;
+        if (F16CopyFromReg.getNumOperands() == 3) {
+          SDValue Glue = N0->getOperand(2);
+          NewCopy = DAG.getCopyFromReg(Chain, DL, BFReg, MVT::bf16, Glue);
+        } else {
+          NewCopy = DAG.getCopyFromReg(Chain, DL, BFReg, MVT::bf16);
+        }
+        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewCopy.getValue(1));
+        return NewCopy;
+      }
+      return SDValue();
+    }
     // If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer
     // type, widen both sides to avoid a trip through memory.
     if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) &&
diff --git a/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
index 0fec5088192b5..95e3d0bf5c04d 100644
--- a/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
+++ b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
@@ -1,17 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 < %s \
 ; RUN:   | FileCheck -check-prefix=RV32F %s
-; RUN: llc -mtriple=riscv64 -mattr=+f,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN: llc -mtriple=riscv64 -mattr=+f,+experimental-zfbfmin -target-abi=lp64 < %s \
 ; RUN:   | FileCheck -check-prefix=RV64F %s
-; RUN: llc -mtriple=riscv32 -mattr=+d,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN: llc -mtriple=riscv32 -mattr=+d,+experimental-zfbfmin -target-abi=ilp32 < %s \
 ; RUN:   | FileCheck -check-prefix=RV32F %s
-; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfbfmin -target-abi=lp64 < %s \
 ; RUN:   | FileCheck -check-prefix=RV64F %s
 
- at gf = external global float
-
-define float @constraint_f_float(bfloat %a) nounwind {
-; RV32F-LABEL: constraint_f_float:
+define float @constraint_f_bfloat(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_bfloat:
 ; RV32F:       # %bb.0:
 ; RV32F-NEXT:    fmv.h.x fa5, a0
 ; RV32F-NEXT:    #APP
@@ -20,7 +18,7 @@ define float @constraint_f_float(bfloat %a) nounwind {
 ; RV32F-NEXT:    fmv.x.w a0, fa5
 ; RV32F-NEXT:    ret
 ;
-; RV64F-LABEL: constraint_f_float:
+; RV64F-LABEL: constraint_f_bfloat:
 ; RV64F:       # %bb.0:
 ; RV64F-NEXT:    fmv.h.x fa5, a0
 ; RV64F-NEXT:    #APP
@@ -28,13 +26,36 @@ define float @constraint_f_float(bfloat %a) nounwind {
 ; RV64F-NEXT:    #NO_APP
 ; RV64F-NEXT:    fmv.x.w a0, fa5
 ; RV64F-NEXT:    ret
-  %1 = load float, float* @gf
   %2 = tail call float asm "fcvt.s.bf16 $0, $1", "=f,f"(bfloat %a)
   ret float %2
 }
 
-define float @constraint_f_float_abi_name(bfloat %a) nounwind {
-; RV32F-LABEL: constraint_f_float_abi_name:
+define bfloat @constraint_bfloat_f(float %x) {
+; RV32F-LABEL: constraint_bfloat_f:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    .cfi_def_cfa_offset 0
+; RV32F-NEXT:    fmv.w.x fa5, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.bf16.s fa5, fa5
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.h a0, fa5
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_bfloat_f:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    .cfi_def_cfa_offset 0
+; RV64F-NEXT:    fmv.w.x fa5, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.bf16.s fa5, fa5
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.h a0, fa5
+; RV64F-NEXT:    ret
+  %1 = tail call bfloat asm sideeffect alignstack "fcvt.bf16.s $0, $1", "=f,f"(float %x)
+  ret bfloat %1
+}
+
+define float @constraint_f_bfloat_abi_name(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_bfloat_abi_name:
 ; RV32F:       # %bb.0:
 ; RV32F-NEXT:    fmv.h.x fa0, a0
 ; RV32F-NEXT:    #APP
@@ -43,7 +64,7 @@ define float @constraint_f_float_abi_name(bfloat %a) nounwind {
 ; RV32F-NEXT:    fmv.x.w a0, ft0
 ; RV32F-NEXT:    ret
 ;
-; RV64F-LABEL: constraint_f_float_abi_name:
+; RV64F-LABEL: constraint_f_bfloat_abi_name:
 ; RV64F:       # %bb.0:
 ; RV64F-NEXT:    fmv.h.x fa0, a0
 ; RV64F-NEXT:    #APP
@@ -51,7 +72,6 @@ define float @constraint_f_float_abi_name(bfloat %a) nounwind {
 ; RV64F-NEXT:    #NO_APP
 ; RV64F-NEXT:    fmv.x.w a0, ft0
 ; RV64F-NEXT:    ret
-  %1 = load float, float* @gf
   %2 = tail call float asm "fcvt.s.bf16 $0, $1", "={ft0},{fa0}"(bfloat %a)
   ret float %2
 }



More information about the llvm-commits mailing list