[llvm] DAG: Handle equal size element build_vector promotion (PR #76213)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 22 00:36:09 PST 2023


https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/76213

This will allow promotion of bf16 to i16 vectors in a future patch.

From a23e10b87a4c45705fa10d378fe665602a9ce979 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Wed, 20 Dec 2023 20:58:09 +0700
Subject: [PATCH] DAG: Handle equal size element build_vector promotion

This will allow promotion of bf16 to i16 vectors in a future
patch.
---
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a483b8028fda9e..4e317062cec497 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4908,7 +4908,9 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
 static MVT getPromotedVectorElementType(const TargetLowering &TLI,
                                         MVT EltVT, MVT NewEltVT) {
   unsigned OldEltsPerNewElt = EltVT.getSizeInBits() / NewEltVT.getSizeInBits();
-  MVT MidVT = MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
+  MVT MidVT = OldEltsPerNewElt == 1
+                  ? NewEltVT
+                  : MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
   assert(TLI.isTypeLegal(MidVT) && "unexpected");
   return MidVT;
 }
@@ -5395,7 +5397,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
 
     assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() &&
            "Invalid promote type for build_vector");
-    assert(NewEltVT.bitsLT(EltVT) && "not handled");
+    assert(NewEltVT.bitsLE(EltVT) && "not handled");
 
     MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
 
@@ -5406,7 +5408,9 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     }
 
     SDLoc SL(Node);
-    SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SL, NVT, NewOps);
+    SDValue Concat =
+        DAG.getNode(MidVT == NewEltVT ? ISD::BUILD_VECTOR : ISD::CONCAT_VECTORS,
+                    SL, NVT, NewOps);
     SDValue CvtVec = DAG.getNode(ISD::BITCAST, SL, OVT, Concat);
     Results.push_back(CvtVec);
     break;



More information about the llvm-commits mailing list