[llvm] [LLVM][RISCV][BF16] Support Inline ASM for the bf16 type. (PR #80118)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 31 00:47:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Chuan-Yue Yuan (circYuan)

<details>
<summary>Changes</summary>

This patch makes the RISC-V inline asm constraint `f` recognize the bfloat type.

---
Full diff: https://github.com/llvm/llvm-project/pull/80118.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+12-1) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+7) 
- (added) llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll (+77) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5ce1013f30fd1..7342e7bcba1f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -9105,7 +9105,7 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
   // Get the actual register value type.  This is important, because the user
   // may have asked for (e.g.) the AX register in i32 type.  We need to
   // remember that AX is actually i16 to get the right extension.
-  const MVT RegVT = *TRI.legalclasstypes_begin(*RC);
+  MVT RegVT = *TRI.legalclasstypes_begin(*RC);
 
   if (OpInfo.ConstraintVT != MVT::Other && RegVT != MVT::Untyped) {
     // If this is an FP operand in an integer register (or visa versa), or more
@@ -9139,6 +9139,17 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
               DAG.getNode(ISD::BITCAST, DL, VT, OpInfo.CallOperand);
         OpInfo.ConstraintVT = VT;
       }
+      // If the register class supports more than one type (e.g. RISC-V's
+      // FPR16RegClass, which holds both f16 and bf16), check whether
+      // OpInfo.ConstraintVT can be assigned to RegVT directly.
+    } else if ((OpInfo.Type == InlineAsm::isOutput ||
+                OpInfo.Type == InlineAsm::isInput) &&
+               TRI.isTypeLegalForClass(*RC, OpInfo.ConstraintVT)) {
+      if (RegVT != OpInfo.ConstraintVT &&
+          RegVT.getSizeInBits() == OpInfo.ConstraintVT.getSizeInBits() &&
+          RegVT.isFloatingPoint() && OpInfo.ConstraintVT.isFloatingPoint()) {
+        RegVT = OpInfo.ConstraintVT;
+      }
     }
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b8994e7b7bdb2..7a6e41ab7fee3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19225,6 +19225,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         return std::make_pair(0U, &RISCV::GPRPairRegClass);
       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
     case 'f':
+      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16)
+        return std::make_pair(0U, &RISCV::FPR16RegClass);
       if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
         return std::make_pair(0U, &RISCV::FPR16RegClass);
       if (Subtarget.hasStdExtF() && VT == MVT::f32)
@@ -19343,6 +19345,11 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         unsigned HReg = RISCV::F0_H + RegNo;
         return std::make_pair(HReg, &RISCV::FPR16RegClass);
       }
+      if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16) {
+        unsigned RegNo = FReg - RISCV::F0_F;
+        unsigned HReg = RISCV::F0_H + RegNo;
+        return std::make_pair(HReg, &RISCV::FPR16RegClass);
+      }
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
new file mode 100644
index 0000000000000..a496e2fea173e
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+f,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64F %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64F %s
+
+@gf = external global float
+
+define float @constraint_f_float(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    fmv.h.x fa5, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.w a0, fa5
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_f_float:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    fmv.h.x fa5, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.w a0, fa5
+; RV64F-NEXT:    ret
+  %1 = load float, ptr @gf
+  %2 = tail call float asm "fcvt.s.bf16 $0, $1", "=f,f"(bfloat %a)
+  ret float %2
+}
+
+define float @constraint_f_float_abi_name(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float_abi_name:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    fmv.h.x fa0, a0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    fcvt.s.bf16 ft0, fa0
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    fmv.x.w a0, ft0
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_f_float_abi_name:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    fmv.h.x fa0, a0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    fcvt.s.bf16 ft0, fa0
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    fmv.x.w a0, ft0
+; RV64F-NEXT:    ret
+  %1 = load float, ptr @gf
+  %2 = tail call float asm "fcvt.s.bf16 $0, $1", "={ft0},{fa0}"(bfloat %a)
+  ret float %2
+}
+
+define bfloat @constraint_gpr(bfloat %x) {
+; RV32F-LABEL: constraint_gpr:
+; RV32F:       # %bb.0:
+; RV32F-NEXT:    .cfi_def_cfa_offset 0
+; RV32F-NEXT:    #APP
+; RV32F-NEXT:    mv a0, a0
+; RV32F-NEXT:    #NO_APP
+; RV32F-NEXT:    ret
+;
+; RV64F-LABEL: constraint_gpr:
+; RV64F:       # %bb.0:
+; RV64F-NEXT:    .cfi_def_cfa_offset 0
+; RV64F-NEXT:    #APP
+; RV64F-NEXT:    mv a0, a0
+; RV64F-NEXT:    #NO_APP
+; RV64F-NEXT:    ret
+  %1 = tail call bfloat asm sideeffect alignstack "mv $0, $1", "={x10},{x10}"(bfloat %x)
+  ret bfloat %1
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/80118


More information about the llvm-commits mailing list