[llvm] 9cc7eb0 - [SVE] Remove redundant hasBF16 calls from lowering code.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 10 06:52:40 PST 2022


Author: Paul Walker
Date: 2022-02-10T14:47:23Z
New Revision: 9cc7eb0ec92dce3a7731a31c106816e0453977ce

URL: https://github.com/llvm/llvm-project/commit/9cc7eb0ec92dce3a7731a31c106816e0453977ce
DIFF: https://github.com/llvm/llvm-project/commit/9cc7eb0ec92dce3a7731a31c106816e0453977ce.diff

LOG: [SVE] Remove redundant hasBF16 calls from lowering code.

There are several places where hasBF16 is used to protect code that
has no requirement for the +bf16 feature.  The lowering code uses
stock SVE instructions for things like loads and stores and so is
safe even when +bf16 is not available.

NOTE: Currently the nxvbf16 type is not legal unless the +bf16
feature is available, but that isn't an issue because the affected
code is post type legalisation.

NOTE: This patch mirrors previous work that removed the same
redundant protection from isel patterns where the resulting
selection emitted stock SVE instructions.

Differential Revision: https://reviews.llvm.org/D119328

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 899f069abdd4..45b2937404ac 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -3901,7 +3901,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
                              true);
         return;
       } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-                 (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+                 VT == MVT::nxv8bf16) {
         SelectPredicatedLoad(Node, 2, 1, AArch64::LD2H_IMM, AArch64::LD2H,
                              true);
         return;
@@ -3922,7 +3922,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
                              true);
         return;
       } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-                 (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+                 VT == MVT::nxv8bf16) {
         SelectPredicatedLoad(Node, 3, 1, AArch64::LD3H_IMM, AArch64::LD3H,
                              true);
         return;
@@ -3943,7 +3943,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
                              true);
         return;
       } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-                 (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+                 VT == MVT::nxv8bf16) {
         SelectPredicatedLoad(Node, 4, 1, AArch64::LD4H_IMM, AArch64::LD4H,
                              true);
         return;
@@ -4267,7 +4267,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
         SelectPredicatedStore(Node, 2, 0, AArch64::ST2B, AArch64::ST2B_IMM);
         return;
       } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-                 (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+                 VT == MVT::nxv8bf16) {
         SelectPredicatedStore(Node, 2, 1, AArch64::ST2H, AArch64::ST2H_IMM);
         return;
       } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
@@ -4284,7 +4284,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
         SelectPredicatedStore(Node, 3, 0, AArch64::ST3B, AArch64::ST3B_IMM);
         return;
       } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-                 (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+                 VT == MVT::nxv8bf16) {
         SelectPredicatedStore(Node, 3, 1, AArch64::ST3H, AArch64::ST3H_IMM);
         return;
       } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
@@ -4301,7 +4301,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
         SelectPredicatedStore(Node, 4, 0, AArch64::ST4B, AArch64::ST4B_IMM);
         return;
       } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-                 (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+                 VT == MVT::nxv8bf16) {
         SelectPredicatedStore(Node, 4, 1, AArch64::ST4H, AArch64::ST4H_IMM);
         return;
       } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
@@ -4911,7 +4911,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
       SelectPredicatedLoad(Node, 2, 0, AArch64::LD2B_IMM, AArch64::LD2B);
       return;
     } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-               (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+               VT == MVT::nxv8bf16) {
       SelectPredicatedLoad(Node, 2, 1, AArch64::LD2H_IMM, AArch64::LD2H);
       return;
     } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
@@ -4928,7 +4928,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
       SelectPredicatedLoad(Node, 3, 0, AArch64::LD3B_IMM, AArch64::LD3B);
       return;
     } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-               (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+               VT == MVT::nxv8bf16) {
       SelectPredicatedLoad(Node, 3, 1, AArch64::LD3H_IMM, AArch64::LD3H);
       return;
     } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
@@ -4945,7 +4945,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
       SelectPredicatedLoad(Node, 4, 0, AArch64::LD4B_IMM, AArch64::LD4B);
       return;
     } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
-               (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
+               VT == MVT::nxv8bf16) {
       SelectPredicatedLoad(Node, 4, 1, AArch64::LD4H_IMM, AArch64::LD4H);
       return;
     } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 41c042cac69a..5dbdba448af1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4610,10 +4610,6 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   EVT MemVT = MGT->getMemoryVT();
   SDValue InputVT = DAG.getValueType(MemVT);
 
-  if (VT.getVectorElementType() == MVT::bf16 &&
-      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
-    return SDValue();
-
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
@@ -4715,10 +4711,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   EVT MemVT = MSC->getMemoryVT();
   SDValue InputVT = DAG.getValueType(MemVT);
 
-  if (VT.getVectorElementType() == MVT::bf16 &&
-      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
-    return SDValue();
-
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
@@ -15801,10 +15793,6 @@ static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   EVT PtrTy = N->getOperand(3).getValueType();
 
-  if (VT == MVT::nxv8bf16 &&
-      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
-    return SDValue();
-
   EVT LoadVT = VT;
   if (VT.isFloatingPoint())
     LoadVT = VT.changeTypeToInteger();
@@ -15832,9 +15820,6 @@ static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) {
                 "Unsupported opcode.");
   SDLoc DL(N);
   EVT VT = N->getValueType(0);
-  if (VT == MVT::nxv8bf16 &&
-      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
-    return SDValue();
 
   EVT LoadVT = VT;
   if (VT.isFloatingPoint())
@@ -15857,10 +15842,6 @@ static SDValue performST1Combine(SDNode *N, SelectionDAG &DAG) {
   EVT HwSrcVt = getSVEContainerType(DataVT);
   SDValue InputVT = DAG.getValueType(DataVT);
 
-  if (DataVT == MVT::nxv8bf16 &&
-      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
-    return SDValue();
-
   if (DataVT.isFloatingPoint())
     InputVT = DAG.getValueType(HwSrcVt);
 
@@ -15887,10 +15868,6 @@ static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) {
   EVT DataVT = Data.getValueType();
   EVT PtrTy = N->getOperand(4).getValueType();
 
-  if (DataVT == MVT::nxv8bf16 &&
-      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
-    return SDValue();
-
   if (DataVT.isFloatingPoint())
     Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data);
 


        


More information about the llvm-commits mailing list