[llvm] [X86] Enable bfloat type support in inline assembly constraints (PR #68469)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 9 04:51:21 PDT 2023


https://github.com/phoebewang updated https://github.com/llvm/llvm-project/pull/68469

>From 0443c0860e98aef51125d5767d321f8a7b7c2106 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Sat, 7 Oct 2023 14:29:23 +0800
Subject: [PATCH 1/2] [X86] Enable bfloat type support in inline assembly
 constraints

Similar to FP16, but there is no native scalar instruction support for
bfloat, so limit it to vector types only.

Fixes #68149
---
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 24 +++++++++++++++++++
 .../X86/inline-asm-avx512f-x-constraint.ll    | 13 +++++++++-
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c4cd2a672fe7b26..c0e93da877a8a10 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56904,6 +56904,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (!Subtarget.hasFP16())
           break;
         [[fallthrough]];
+      case MVT::v8bf16:
+        if (!Subtarget.hasBF16())
+          break;
+        [[fallthrough]];
       case MVT::f128:
       case MVT::v16i8:
       case MVT::v8i16:
@@ -56919,6 +56923,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (!Subtarget.hasFP16())
           break;
         [[fallthrough]];
+      case MVT::v16bf16:
+        if (!Subtarget.hasBF16())
+          break;
+        [[fallthrough]];
       case MVT::v32i8:
       case MVT::v16i16:
       case MVT::v8i32:
@@ -56934,6 +56942,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (!Subtarget.hasFP16())
           break;
         [[fallthrough]];
+      case MVT::v32bf16:
+        if (!Subtarget.hasBF16())
+          break;
+        [[fallthrough]];
       case MVT::v64i8:
       case MVT::v32i16:
       case MVT::v8f64:
@@ -56977,6 +56989,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (!Subtarget.hasFP16())
           break;
         [[fallthrough]];
+      case MVT::v8bf16:
+        if (!Subtarget.hasBF16())
+          break;
+        [[fallthrough]];
       case MVT::f128:
       case MVT::v16i8:
       case MVT::v8i16:
@@ -56990,6 +57006,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (!Subtarget.hasFP16())
           break;
         [[fallthrough]];
+      case MVT::v16bf16:
+        if (!Subtarget.hasBF16())
+          break;
+        [[fallthrough]];
       case MVT::v32i8:
       case MVT::v16i16:
       case MVT::v8i32:
@@ -57003,6 +57023,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         if (!Subtarget.hasFP16())
           break;
         [[fallthrough]];
+      case MVT::v32bf16:
+        if (!Subtarget.hasBF16())
+          break;
+        [[fallthrough]];
       case MVT::v64i8:
       case MVT::v32i16:
       case MVT::v8f64:
diff --git a/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll b/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll
index fcea55c47cd3ec4..e153387d16e72b1 100644
--- a/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll
+++ b/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll
@@ -1,7 +1,7 @@
 ; RUN: not llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512f -stop-after=finalize-isel > %t 2> %t.err
 ; RUN: FileCheck < %t %s
 ; RUN: FileCheck --check-prefix=CHECK-STDERR < %t.err %s
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512fp16 -stop-after=finalize-isel | FileCheck --check-prefixes=CHECK,FP16 %s
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512bf16,avx512fp16 -stop-after=finalize-isel | FileCheck --check-prefixes=CHECK,FP16 %s
 
 ; CHECK-LABEL: name: mask_Yk_i8
 ; CHECK: %[[REG1:.*]]:vr512_0_15 = COPY %1
@@ -24,3 +24,14 @@ entry:
   %0 = tail call <32 x half> asm "vaddph\09$3, $2, $0 {$1}", "=x,^Yk,x,x,~{dirflag},~{fpsr},~{flags}"(i8 %msk, <32 x half> %x, <32 x half> %y)
   ret <32 x half> %0
 }
+
+; FP16-LABEL: name: mask_Yk_bf16
+; FP16: %[[REG1:.*]]:vr512_0_15 = COPY %1
+; FP16: %[[REG2:.*]]:vr512_0_15 = COPY %2
+; FP16: INLINEASM &"vaddph\09$3, $2, $0 {$1}", 0 /* attdialect */, {{.*}}, def %{{.*}}, {{.*}}, %{{.*}}, {{.*}}, %[[REG1]], {{.*}}, %[[REG2]], 12 /* clobber */, implicit-def early-clobber $df, 12 /* clobber */, implicit-def early-clobber $fpsw, 12 /* clobber */, implicit-def early-clobber $eflags
+; CHECK-STDERR: couldn't allocate output register for constraint 'x'
+define <32 x bfloat> @mask_Yk_bf16(i8 signext %msk, <32 x bfloat> %x, <32 x bfloat> %y) {
+entry:
+  %0 = tail call <32 x bfloat> asm "vaddph\09$3, $2, $0 {$1}", "=x,^Yk,x,x,~{dirflag},~{fpsr},~{flags}"(i8 %msk, <32 x bfloat> %x, <32 x bfloat> %y)
+  ret <32 x bfloat> %0
+}

>From ee98b7dd9df4ab6d1afe187bc900cdb7f8ca5460 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Mon, 9 Oct 2023 19:50:50 +0800
Subject: [PATCH 2/2] Do not use [[fallthrough]]

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 36 ++++++++++++++++---------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c0e93da877a8a10..6a9f39ada651ca2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56903,11 +56903,15 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       case MVT::v8f16:
         if (!Subtarget.hasFP16())
           break;
-        [[fallthrough]];
+        if (VConstraint)
+          return std::make_pair(0U, &X86::VR128XRegClass);
+        return std::make_pair(0U, &X86::VR128RegClass);
       case MVT::v8bf16:
         if (!Subtarget.hasBF16())
           break;
-        [[fallthrough]];
+        if (VConstraint)
+          return std::make_pair(0U, &X86::VR128XRegClass);
+        return std::make_pair(0U, &X86::VR128RegClass);
       case MVT::f128:
       case MVT::v16i8:
       case MVT::v8i16:
@@ -56922,11 +56926,15 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       case MVT::v16f16:
         if (!Subtarget.hasFP16())
           break;
-        [[fallthrough]];
+        if (VConstraint)
+          return std::make_pair(0U, &X86::VR256XRegClass);
+        return std::make_pair(0U, &X86::VR256RegClass);
       case MVT::v16bf16:
         if (!Subtarget.hasBF16())
           break;
-        [[fallthrough]];
+        if (VConstraint)
+          return std::make_pair(0U, &X86::VR256XRegClass);
+        return std::make_pair(0U, &X86::VR256RegClass);
       case MVT::v32i8:
       case MVT::v16i16:
       case MVT::v8i32:
@@ -56941,11 +56949,15 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       case MVT::v32f16:
         if (!Subtarget.hasFP16())
           break;
-        [[fallthrough]];
+        if (VConstraint)
+          return std::make_pair(0U, &X86::VR512RegClass);
+        return std::make_pair(0U, &X86::VR512_0_15RegClass);
       case MVT::v32bf16:
         if (!Subtarget.hasBF16())
           break;
-        [[fallthrough]];
+        if (VConstraint)
+          return std::make_pair(0U, &X86::VR512RegClass);
+        return std::make_pair(0U, &X86::VR512_0_15RegClass);
       case MVT::v64i8:
       case MVT::v32i16:
       case MVT::v8f64:
@@ -56988,11 +57000,11 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       case MVT::v8f16:
         if (!Subtarget.hasFP16())
           break;
-        [[fallthrough]];
+        return std::make_pair(X86::XMM0, &X86::VR128RegClass);
       case MVT::v8bf16:
         if (!Subtarget.hasBF16())
           break;
-        [[fallthrough]];
+        return std::make_pair(X86::XMM0, &X86::VR128RegClass);
       case MVT::f128:
       case MVT::v16i8:
       case MVT::v8i16:
@@ -57005,11 +57017,11 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       case MVT::v16f16:
         if (!Subtarget.hasFP16())
           break;
-        [[fallthrough]];
+        return std::make_pair(X86::YMM0, &X86::VR256RegClass);
       case MVT::v16bf16:
         if (!Subtarget.hasBF16())
           break;
-        [[fallthrough]];
+        return std::make_pair(X86::YMM0, &X86::VR256RegClass);
       case MVT::v32i8:
       case MVT::v16i16:
       case MVT::v8i32:
@@ -57022,11 +57034,11 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       case MVT::v32f16:
         if (!Subtarget.hasFP16())
           break;
-        [[fallthrough]];
+        return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass);
       case MVT::v32bf16:
         if (!Subtarget.hasBF16())
           break;
-        [[fallthrough]];
+        return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass);
       case MVT::v64i8:
       case MVT::v32i16:
       case MVT::v8f64:



More information about the llvm-commits mailing list