[llvm] 80c534a - [GlobalISel][CallLowering] Fix crash when handling a v3s32 type that's being passed as v2s64.

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Fri May 14 16:30:57 PDT 2021


Author: Amara Emerson
Date: 2021-05-14T16:30:51-07:00
New Revision: 80c534a8f97fef050ebbe3411413018abd2ca2ae

URL: https://github.com/llvm/llvm-project/commit/80c534a8f97fef050ebbe3411413018abd2ca2ae
DIFF: https://github.com/llvm/llvm-project/commit/80c534a8f97fef050ebbe3411413018abd2ca2ae.diff

LOG: [GlobalISel][CallLowering] Fix crash when handling a v3s32 type that's being passed as v2s64.

Added: 
    

Modified: 
    llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
    llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-vectors.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
index d0db5e2ee31fb..e460fd8950c26 100644
--- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
@@ -328,19 +328,23 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
     return;
   }
 
+  // A vector PartLLT needs extending to LLTy's element size.
+  // E.g. <2 x s64> = G_SEXT <2 x s32>.
   if (PartLLT.isVector() == LLTy.isVector() &&
       PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() &&
+      (!PartLLT.isVector() ||
+       PartLLT.getNumElements() == LLTy.getNumElements()) &&
       OrigRegs.size() == 1 && Regs.size() == 1) {
     Register SrcReg = Regs[0];
 
     LLT LocTy = MRI.getType(SrcReg);
 
     if (Flags.isSExt()) {
-      SrcReg = B.buildAssertSExt(LocTy, SrcReg,
-                                 LLTy.getScalarSizeInBits()).getReg(0);
+      SrcReg = B.buildAssertSExt(LocTy, SrcReg, LLTy.getScalarSizeInBits())
+                   .getReg(0);
     } else if (Flags.isZExt()) {
-      SrcReg = B.buildAssertZExt(LocTy, SrcReg,
-                                 LLTy.getScalarSizeInBits()).getReg(0);
+      SrcReg = B.buildAssertZExt(LocTy, SrcReg, LLTy.getScalarSizeInBits())
+                   .getReg(0);
     }
 
     B.buildTrunc(OrigRegs[0], SrcReg);
@@ -364,18 +368,30 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
 
   if (PartLLT.isVector()) {
     assert(OrigRegs.size() == 1);
+    SmallVector<Register> CastRegs(Regs.begin(), Regs.end());
+
+    // If PartLLT is a mismatched vector in both number of elements and element
+    // size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to
+    // have the same elt type, i.e. v4s32.
+    if (PartLLT.getSizeInBits() > LLTy.getSizeInBits() &&
+        PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 &&
+        Regs.size() == 1) {
+      LLT NewTy = PartLLT.changeElementType(LLTy.getElementType())
+                      .changeNumElements(PartLLT.getNumElements() * 2);
+      CastRegs[0] = B.buildBitcast(NewTy, Regs[0]).getReg(0);
+      PartLLT = NewTy;
+    }
 
     if (LLTy.getScalarType() == PartLLT.getElementType()) {
-      mergeVectorRegsToResultRegs(B, OrigRegs, Regs);
+      mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs);
     } else {
-      SmallVector<Register> CastRegs(Regs.size());
       unsigned I = 0;
       LLT GCDTy = getGCDType(LLTy, PartLLT);
 
       // We are both splitting a vector, and bitcasting its element types. Cast
       // the source pieces into the appropriate number of pieces with the result
       // element type.
-      for (Register SrcReg : Regs)
+      for (Register SrcReg : CastRegs)
         CastRegs[I++] = B.buildBitcast(GCDTy, SrcReg).getReg(0);
       mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs);
     }

diff  --git a/llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-vectors.ll b/llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-vectors.ll
index f34f0981c2116..bee323ad69d4b 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-vectors.ll
@@ -44,3 +44,22 @@ define <1 x half> @test_v1s16(<1 x float> %x) {
   %tmp = fptrunc <1 x float> %x to <1 x half>
   ret <1 x half> %tmp
 }
+
+declare <3 x float> @bar(float)
+define void @test_return_v3f32() {
+  ; CHECK-LABEL: name: test_return_v3f32
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK:   [[DEF:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF
+  ; CHECK:   ADJCALLSTACKDOWN 0, 0, implicit-def $sp, implicit $sp
+  ; CHECK:   $s0 = COPY [[DEF]](s32)
+  ; CHECK:   BL @bar, csr_aarch64_aapcs, implicit-def $lr, implicit $sp, implicit $s0, implicit-def $q0
+  ; CHECK:   [[COPY:%[0-9]+]]:_(<2 x s64>) = COPY $q0
+  ; CHECK:   [[BITCAST:%[0-9]+]]:_(<4 x s32>) = G_BITCAST [[COPY]](<2 x s64>)
+  ; CHECK:   [[DEF1:%[0-9]+]]:_(<4 x s32>) = G_IMPLICIT_DEF
+  ; CHECK:   [[CONCAT_VECTORS:%[0-9]+]]:_(<12 x s32>) = G_CONCAT_VECTORS [[BITCAST]](<4 x s32>), [[DEF1]](<4 x s32>), [[DEF1]](<4 x s32>)
+  ; CHECK:   [[UV:%[0-9]+]]:_(<3 x s32>), [[UV1:%[0-9]+]]:_(<3 x s32>), [[UV2:%[0-9]+]]:_(<3 x s32>), [[UV3:%[0-9]+]]:_(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>)
+  ; CHECK:   ADJCALLSTACKUP 0, 0, implicit-def $sp, implicit $sp
+  ; CHECK:   RET_ReallyLR
+  %call = call <3 x float> @bar(float undef)
+  ret void
+}


        


More information about the llvm-commits mailing list