[llvm] 73e3866 - [AArch64][SME] Promote mask for masked load to a similar type size with load value.

Dinar Temirbulatov via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 30 01:55:43 PDT 2023


Author: Dinar Temirbulatov
Date: 2023-08-30T08:54:46Z
New Revision: 73e3866acbfd132d41e54776ac6b02f86c23a9e5

URL: https://github.com/llvm/llvm-project/commit/73e3866acbfd132d41e54776ac6b02f86c23a9e5
DIFF: https://github.com/llvm/llvm-project/commit/73e3866acbfd132d41e54776ac6b02f86c23a9e5.diff

LOG: [AArch64][SME] Promote mask for masked load to a similar type size with load value.

The legalizer can keep the original mask type of a masked load that has been
combined with a sign/zero extend. In that case the mask must be extended to the
element size of the combined (extended) load; otherwise instruction selection
cannot lower the load.

Differential Revision: https://reviews.llvm.org/D158386

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 547f47d835921e..145e5c8173bb29 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24971,7 +24971,15 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMLoadToSVE(
   EVT VT = Op.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
 
-  SDValue Mask = convertFixedMaskToScalableVector(Load->getMask(), DAG);
+  SDValue Mask = Load->getMask();
+  // If this is an extending load and the mask type is not the same as
+  // load's type then we have to extend the mask type.
+  if (VT.getScalarSizeInBits() > Mask.getValueType().getScalarSizeInBits()) {
+    assert(Load->getExtensionType() != ISD::NON_EXTLOAD &&
+           "Incorrect mask type");
+    Mask = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Mask);
+  }
+  Mask = convertFixedMaskToScalableVector(Mask, DAG);
 
   SDValue PassThru;
   bool IsPassThruZeroOrUndef = false;

diff  --git a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll
index 9785d795744ef8..ae0f84328276f2 100644
--- a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll
+++ b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll
@@ -335,6 +335,58 @@ define <4 x double> @masked_load_v4f64(ptr %src, <4 x i1> %mask) {
   ret <4 x double> %load
 }
 
+define <3 x i32> @masked_load_zext_v3i32(ptr %load_ptr, <3 x i1> %pm) {
+; CHECK-LABEL: masked_load_zext_v3i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    strh w3, [sp, #12]
+; CHECK-NEXT:    adrp x8, .LCPI13_0
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    strh w2, [sp, #10]
+; CHECK-NEXT:    ldr d0, [x8, :lo12:.LCPI13_0]
+; CHECK-NEXT:    strh w1, [sp, #8]
+; CHECK-NEXT:    ldr d1, [sp, #8]
+; CHECK-NEXT:    and z0.d, z1.d, z0.d
+; CHECK-NEXT:    lsl z0.h, z0.h, #15
+; CHECK-NEXT:    asr z0.h, z0.h, #15
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    cmpne p0.s, p0/z, z0.s, #0
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %load_value = tail call <3 x i16> @llvm.masked.load.v3i16.p0(ptr %load_ptr, i32 4, <3 x i1> %pm, <3 x i16> zeroinitializer)
+  %extend = zext <3 x i16> %load_value to <3 x i32>
+  ret <3 x i32> %extend;
+}
+
+define <3 x i32> @masked_load_sext_v3i32(ptr %load_ptr, <3 x i1> %pm) {
+; CHECK-LABEL: masked_load_sext_v3i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    strh w3, [sp, #12]
+; CHECK-NEXT:    adrp x8, .LCPI14_0
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    strh w2, [sp, #10]
+; CHECK-NEXT:    ldr d0, [x8, :lo12:.LCPI14_0]
+; CHECK-NEXT:    strh w1, [sp, #8]
+; CHECK-NEXT:    ldr d1, [sp, #8]
+; CHECK-NEXT:    and z0.d, z1.d, z0.d
+; CHECK-NEXT:    lsl z0.h, z0.h, #15
+; CHECK-NEXT:    asr z0.h, z0.h, #15
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    cmpne p0.s, p0/z, z0.s, #0
+; CHECK-NEXT:    ld1sh { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %load_value = tail call <3 x i16> @llvm.masked.load.v3i16.p0(ptr %load_ptr, i32 4, <3 x i1> %pm, <3 x i16> zeroinitializer)
+  %extend = sext <3 x i16> %load_value to <3 x i32>
+  ret <3 x i32> %extend;
+}
+
 declare <4 x i8> @llvm.masked.load.v4i8(ptr, i32, <4 x i1>, <4 x i8>)
 declare <8 x i8> @llvm.masked.load.v8i8(ptr, i32, <8 x i1>, <8 x i8>)
 declare <16 x i8> @llvm.masked.load.v16i8(ptr, i32, <16 x i1>, <16 x i8>)
@@ -351,3 +403,5 @@ declare <8 x float> @llvm.masked.load.v8f32(ptr, i32, <8 x i1>, <8 x float>)
 
 declare <2 x double> @llvm.masked.load.v2f64(ptr, i32, <2 x i1>, <2 x double>)
 declare <4 x double> @llvm.masked.load.v4f64(ptr, i32, <4 x i1>, <4 x double>)
+
+declare <3 x i16> @llvm.masked.load.v3i16.p0(ptr, i32, <3 x i1>, <3 x i16>)


        


More information about the llvm-commits mailing list