[llvm] [AArch64] Fix lowering error for masked load/store integer scalable ve… (PR #99354)

Dinar Temirbulatov via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 09:49:29 PDT 2024


https://github.com/dtemirbulatov created https://github.com/llvm/llvm-project/pull/99354

…ctors one element types.

The proposed change fixes an error when lowering masked loads/stores of the single-element integer scalable vector types nxv1i8, nxv1i16, nxv1i32 and nxv1i64.

>From ec8a0578e27368b24151e2f7efa7e63d2e2985e1 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Wed, 17 Jul 2024 16:02:41 +0000
Subject: [PATCH] [AArch64] Fix lowering error for masked load/store integer
 scalable vector one-element types.

The proposed change fixes an error when lowering masked loads/stores of the
single-element integer scalable vector types nxv1i8, nxv1i16, nxv1i32 and nxv1i64.
---
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  62 +++++---
 .../CodeGen/AArch64/sve-nxv1-load-store.ll    | 142 ++++++++++++++++++
 2 files changed, 184 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1a575abbc16f4..bb210fcd7bdab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5749,18 +5749,27 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
 
 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
 
-  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+  EVT VT = N->getValueType(0);
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue PassThru = GetWidenedVector(N->getPassThru());
   ISD::LoadExtType ExtType = N->getExtensionType();
   SDLoc dl(N);
 
-  // The mask should be widened as well
-  EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                    MaskVT.getVectorElementType(),
-                                    WidenVT.getVectorNumElements());
-  Mask = ModifyToType(Mask, WideMaskVT, true);
+  if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+      VT == MVT::nxv1i64) {
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WidenVT.getVectorMinNumElements(), true);
+    Mask = ModifyToType(Mask, WideMaskVT, true);
+  } else {
+    // The mask should be widened as well
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WidenVT.getVectorNumElements());
+    Mask = ModifyToType(Mask, WideMaskVT, true);
+  }
 
   SDValue Res = DAG.getMaskedLoad(
       WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6914,30 +6923,43 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
   SDValue StVal = MST->getValue();
   SDLoc dl(N);
 
-  if (OpNo == 1) {
-    // Widen the value.
-    StVal = GetWidenedVector(StVal);
+  EVT VT = StVal.getValueType();
 
-    // The mask should be widened as well.
-    EVT WideVT = StVal.getValueType();
-    EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                      MaskVT.getVectorElementType(),
-                                      WideVT.getVectorNumElements());
-    Mask = ModifyToType(Mask, WideMaskVT, true);
+  if (OpNo == 1) {
+    if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+        VT == MVT::nxv1i64) {
+      EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+      EVT WideMaskVT =
+          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                           WidenVT.getVectorMinNumElements(), true);
+      StVal = ModifyToType(StVal, WidenVT);
+      Mask = ModifyToType(Mask, WideMaskVT, true);
+    } else {
+      // Widen the value.
+      StVal = GetWidenedVector(StVal);
+
+      // The mask should be widened as well.
+      EVT WideVT = StVal.getValueType();
+      EVT WideMaskVT =
+          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                           WideVT.getVectorNumElements());
+      Mask = ModifyToType(Mask, WideMaskVT, true);
+    }
   } else {
     // Widen the mask.
     EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
     Mask = ModifyToType(Mask, WideMaskVT, true);
 
-    EVT ValueVT = StVal.getValueType();
-    EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
-                                  ValueVT.getVectorElementType(),
+    EVT WideVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
                                   WideMaskVT.getVectorNumElements());
     StVal = ModifyToType(StVal, WideVT);
   }
 
-  assert(Mask.getValueType().getVectorNumElements() ==
-         StVal.getValueType().getVectorNumElements() &&
+  assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
+                                : Mask.getValueType().getVectorNumElements()) ==
+             (VT.isScalableVector()
+                  ? StVal.getValueType().getVectorMinNumElements()
+                  : StVal.getValueType().getVectorNumElements()) &&
          "Mask and data vectors should have the same number of elements");
   return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
                             MST->getOffset(), Mask, MST->getMemoryVT(),
diff --git a/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
new file mode 100644
index 0000000000000..d7820273a6f78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s --check-prefix CHECK
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @store_i8(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i8> %val) #0 {
+; CHECK-LABEL: store_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i16(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i16> %val) #0 {
+; CHECK-LABEL: store_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i32(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i32> %val) #0 {
+; CHECK-LABEL: store_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i64(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i64> %val) #0 {
+; CHECK-LABEL: store_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @load_store_i32(<vscale x 1 x i1> %pred, ptr %x, ptr %y) #0 {
+; CHECK-LABEL: load_store_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x1]
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 2, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %load, ptr %y, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define <vscale x 1 x i8> @load_i8(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i8> undef)
+  ret <vscale x 1 x i8> %ret
+}
+
+define <vscale x 1 x i16> @load_i16(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i16> undef)
+  ret <vscale x 1 x i16> %ret
+}
+
+define <vscale x 1 x i32> @load_i32(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+  ret <vscale x 1 x i32> %ret
+}
+
+define <vscale x 1 x i64> @load_i64(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i64> undef)
+  ret <vscale x 1 x i64> %ret
+}
+
+declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
+declare <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i16>)
+declare <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i32>)
+declare <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64>, ptr, i32 immarg, <vscale x 1 x i1>)
+
+attributes #0 = { "target-features"="+sve" }



More information about the llvm-commits mailing list