[llvm] [RISCV] Support Inline ASM for the bf16 type. (PR #80118)

Chuan-Yue Yuan via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 02:17:10 PST 2024


https://github.com/circYuan updated https://github.com/llvm/llvm-project/pull/80118

>From 60dfccddb3b7003355369221726fbacbc0fd89d8 Mon Sep 17 00:00:00 2001
From: Tony Chuan-Yue Yuan <yuan593 at andestech.com>
Date: Wed, 31 Jan 2024 16:35:04 +0800
Subject: [PATCH 1/3] [LLVM][RISCV][BF16] Support Inline ASM for the bf16 type.

This patch makes the RISC-V inline asm constraint `f` recognize the bfloat
type.
---
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 13 +++-
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  7 ++
 .../RISCV/inline-asm-f-constraint-bf16.ll     | 77 +++++++++++++++++++
 3 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5ce1013f30fd1..7342e7bcba1f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -9105,7 +9105,7 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
   // Get the actual register value type.  This is important, because the user
   // may have asked for (e.g.) the AX register in i32 type.  We need to
   // remember that AX is actually i16 to get the right extension.
-  const MVT RegVT = *TRI.legalclasstypes_begin(*RC);
+  MVT RegVT = *TRI.legalclasstypes_begin(*RC);
 
   if (OpInfo.ConstraintVT != MVT::Other && RegVT != MVT::Untyped) {
     // If this is an FP operand in an integer register (or visa versa), or more
@@ -9139,6 +9139,17 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
               DAG.getNode(ISD::BITCAST, DL, VT, OpInfo.CallOperand);
         OpInfo.ConstraintVT = VT;
       }
+      // If the RegisterClass contains more than one types like RISCV
+      // FPR16RegClass which has [f16, bf16], We should check if the
+      // OpInfo.ConstraintVT can directly be assigned to the RegVT.
+    } else if ((OpInfo.Type == InlineAsm::isOutput ||
+                OpInfo.Type == InlineAsm::isInput) &&
+               TRI.isTypeLegalForClass(*RC, OpInfo.ConstraintVT)) {
+      if (RegVT != OpInfo.ConstraintVT &&
+          RegVT.getSizeInBits() == OpInfo.ConstraintVT.getSizeInBits() &&
+          RegVT.isFloatingPoint() && OpInfo.ConstraintVT.isFloatingPoint()) {
+        RegVT = OpInfo.ConstraintVT;
+      }
     }
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b8994e7b7bdb2..7a6e41ab7fee3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19225,6 +19225,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         return std::make_pair(0U, &RISCV::GPRPairRegClass);
       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
     case 'f':
+      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16)
+        return std::make_pair(0U, &RISCV::FPR16RegClass);
       if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
         return std::make_pair(0U, &RISCV::FPR16RegClass);
       if (Subtarget.hasStdExtF() && VT == MVT::f32)
@@ -19343,6 +19345,11 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         unsigned HReg = RISCV::F0_H + RegNo;
         return std::make_pair(HReg, &RISCV::FPR16RegClass);
       }
+      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16) {
+        unsigned RegNo = FReg - RISCV::F0_F;
+        unsigned HReg = RISCV::F0_H + RegNo;
+        return std::make_pair(HReg, &RISCV::FPR16RegClass);
+      }
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
new file mode 100644
index 0000000000000..a496e2fea173e
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+f,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64F %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64F %s
+
+ at gf = external global float
+
+define float @constraint_f_float(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    fmv.h.x fa5, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.w a0, fa5
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_f_float:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    fmv.h.x fa5, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.w a0, fa5
+; RV64F-NEXT:    ret
+  %1 = load float, ptr @gf
+  %2 = tail call float asm "fcvt.s.bf16 $0, $1", "=f,f"(bfloat %a)
+  ret float %2
+}
+
+define float @constraint_f_float_abi_name(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float_abi_name:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    fmv.h.x fa0, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.s.bf16 ft0, fa0
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.w a0, ft0
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_f_float_abi_name:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    fmv.h.x fa0, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.s.bf16 ft0, fa0
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.w a0, ft0
+; RV64F-NEXT:    ret
+  %1 = load float, ptr @gf
+  %2 = tail call float asm "fcvt.s.bf16 $0, $1", "={ft0},{fa0}"(bfloat %a)
+  ret float %2
+}
+
+define bfloat @constraint_gpr(bfloat %x) {
+; RV32F-LABEL: constraint_gpr:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    .cfi_def_cfa_offset 0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    mv a0, a0
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_gpr:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    .cfi_def_cfa_offset 0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    mv a0, a0
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    ret
+  %1 = tail call bfloat asm sideeffect alignstack "mv $0, $1", "={x10},{x10}"(bfloat %x)
+  ret bfloat %1
+}

>From 3f5a575f074ede400cf9f9341fa660e54cdec310 Mon Sep 17 00:00:00 2001
From: Tony Chuan-Yue Yuan <yuan593 at andestech.com>
Date: Wed, 31 Jan 2024 16:56:44 +0800
Subject: [PATCH 2/3] Merge the identical code blocks

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7a6e41ab7fee3..c2a1617a3d86e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19340,12 +19340,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       }
       if (VT == MVT::f32 || VT == MVT::Other)
         return std::make_pair(FReg, &RISCV::FPR32RegClass);
-      if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) {
-        unsigned RegNo = FReg - RISCV::F0_F;
-        unsigned HReg = RISCV::F0_H + RegNo;
-        return std::make_pair(HReg, &RISCV::FPR16RegClass);
-      }
-      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16) {
+      if ((Subtarget.hasStdExtZfhmin() && VT == MVT::f16) ||
+          (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16)) {
         unsigned RegNo = FReg - RISCV::F0_F;
         unsigned HReg = RISCV::F0_H + RegNo;
         return std::make_pair(HReg, &RISCV::FPR16RegClass);

>From eabaea0f84fbe0e1495e043f1a894590d9303e0b Mon Sep 17 00:00:00 2001
From: Tony Chuan-Yue Yuan <yuan593 at andestech.com>
Date: Mon, 5 Feb 2024 18:02:34 +0800
Subject: [PATCH 3/3] Refactor the generic code.

Give each target the opportunity to choose the proper register value type
when building the inline asm SelectionDAG. In the RISC-V case, when
ConstraintVT is bf16 we must not pick f16 (the first type listed for
FPR16RegClass) as the register type, as this would cause SelectionDAG to
create an incorrect DAG, such as t1: bf16 = soft_promote t0, where t0 is f16.
---
 llvm/include/llvm/CodeGen/TargetLowering.h         |  9 +++++++++
 .../CodeGen/SelectionDAG/SelectionDAGBuilder.cpp   | 14 ++------------
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp        | 12 ++++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h          |  4 ++++
 4 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index d39094aa7fed7..f17f514b9815d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4882,6 +4882,15 @@ class TargetLowering : public TargetLoweringBase {
   getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                                StringRef Constraint, MVT VT) const;
 
+  /// Given a TargetRegisterClass and the constraint value type, let the
+  /// target choose the register value type used for an inline asm operand.
+  /// By default, this is the first legal type of the TargetRegisterClass.
+  virtual MVT getRegVTFromInlineAsmConstraint(const TargetRegisterInfo *TRI,
+                                              const TargetRegisterClass *RC,
+                                              const MVT ConstraintVT) const {
+    return *TRI->legalclasstypes_begin(*RC);
+  }
+
   virtual InlineAsm::ConstraintCode
   getInlineAsmMemConstraint(StringRef ConstraintCode) const {
     if (ConstraintCode == "m")
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 7342e7bcba1f2..1a62fe142f58f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -9105,7 +9105,8 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
   // Get the actual register value type.  This is important, because the user
   // may have asked for (e.g.) the AX register in i32 type.  We need to
   // remember that AX is actually i16 to get the right extension.
-  MVT RegVT = *TRI.legalclasstypes_begin(*RC);
+  const MVT RegVT =
+      TLI.getRegVTFromInlineAsmConstraint(&TRI, RC, OpInfo.ConstraintVT);
 
   if (OpInfo.ConstraintVT != MVT::Other && RegVT != MVT::Untyped) {
     // If this is an FP operand in an integer register (or visa versa), or more
@@ -9139,17 +9140,6 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
               DAG.getNode(ISD::BITCAST, DL, VT, OpInfo.CallOperand);
         OpInfo.ConstraintVT = VT;
       }
-      // If the RegisterClass contains more than one types like RISCV
-      // FPR16RegClass which has [f16, bf16], We should check if the
-      // OpInfo.ConstraintVT can directly be assigned to the RegVT.
-    } else if ((OpInfo.Type == InlineAsm::isOutput ||
-                OpInfo.Type == InlineAsm::isInput) &&
-               TRI.isTypeLegalForClass(*RC, OpInfo.ConstraintVT)) {
-      if (RegVT != OpInfo.ConstraintVT &&
-          RegVT.getSizeInBits() == OpInfo.ConstraintVT.getSizeInBits() &&
-          RegVT.isFloatingPoint() && OpInfo.ConstraintVT.isFloatingPoint()) {
-        RegVT = OpInfo.ConstraintVT;
-      }
     }
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2a1617a3d86e..7b656573aa617 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19413,6 +19413,18 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   return Res;
 }
 
+MVT RISCVTargetLowering::getRegVTFromInlineAsmConstraint(
+    const TargetRegisterInfo *TRI, const TargetRegisterClass *RC,
+    const MVT ConstraintVT) const {
+  // FPR16RegClass contains [f16, bf16], and when the ConstraintVT is bf16, we
+  // should choose the bf16 instead of the first type, which is f16.
+  if (ConstraintVT == MVT::bf16 && RC == &RISCV::FPR16RegClass) {
+    return MVT::bf16;
+  }
+
+  return TargetLowering::getRegVTFromInlineAsmConstraint(TRI, RC, ConstraintVT);
+}
+
 InlineAsm::ConstraintCode
 RISCVTargetLowering::getInlineAsmMemConstraint(StringRef ConstraintCode) const {
   // Currently only support length 1 constraints.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 255b1d0e15eed..0fd57aa59aa84 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -608,6 +608,10 @@ class RISCVTargetLowering : public TargetLowering {
   getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                                StringRef Constraint, MVT VT) const override;
 
+  MVT getRegVTFromInlineAsmConstraint(const TargetRegisterInfo *TRI,
+                                      const TargetRegisterClass *RC,
+                                      const MVT ConstraintVT) const override;
+
   void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint,
                                     std::vector<SDValue> &Ops,
                                     SelectionDAG &DAG) const override;



More information about the llvm-commits mailing list