[llvm] [AArch64] Fix lowering error for masked load/store integer scalable ve… (PR #99354)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 09:50:00 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: Dinar Temirbulatov (dtemirbulatov)

Changes

…ctors with one-element types.

The proposed change fixes an error when lowering masked load/store of the integer scalable vector types nxv1i8, nxv1i16, nxv1i32, and nxv1i64.
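
For context, the failure can be reproduced with a reduced case like the one below, adapted from the added sve-nxv1-load-store.ll test (the function name here is illustrative). Before this change, the widened mask type for these operations was built from getVectorNumElements(), which is not valid for scalable element counts; the patch instead uses getVectorMinNumElements() with the scalable flag set for these types.

```llvm
; Reduced reproducer, adapted from the new test file.
; Run with: llc -mtriple=aarch64-unknown-linux-gnu <this file>
; Type legalization must widen the nxv1i32 value and its nxv1i1 mask to a
; legal SVE type before the store can be selected as st1w.
define void @masked_store_nxv1i32(<vscale x 1 x i1> %pred, ptr %x, <vscale x 1 x i32> %val) #0 {
  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
  ret void
}

declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)

attributes #0 = { "target-features"="+sve" }
```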

---
Full diff: https://github.com/llvm/llvm-project/pull/99354.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+42-20) 
- (added) llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll (+142) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1a575abbc16f4..bb210fcd7bdab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5749,18 +5749,27 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
 
 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
 
-  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+  EVT VT = N->getValueType(0);
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue PassThru = GetWidenedVector(N->getPassThru());
   ISD::LoadExtType ExtType = N->getExtensionType();
   SDLoc dl(N);
 
-  // The mask should be widened as well
-  EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                    MaskVT.getVectorElementType(),
-                                    WidenVT.getVectorNumElements());
-  Mask = ModifyToType(Mask, WideMaskVT, true);
+  if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+      VT == MVT::nxv1i64) {
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WidenVT.getVectorMinNumElements(), true);
+    Mask = ModifyToType(Mask, WideMaskVT, true);
+  } else {
+    // The mask should be widened as well
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WidenVT.getVectorNumElements());
+    Mask = ModifyToType(Mask, WideMaskVT, true);
+  }
 
   SDValue Res = DAG.getMaskedLoad(
       WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6914,30 +6923,43 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
   SDValue StVal = MST->getValue();
   SDLoc dl(N);
 
-  if (OpNo == 1) {
-    // Widen the value.
-    StVal = GetWidenedVector(StVal);
+  EVT VT = StVal.getValueType();
 
-    // The mask should be widened as well.
-    EVT WideVT = StVal.getValueType();
-    EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                      MaskVT.getVectorElementType(),
-                                      WideVT.getVectorNumElements());
-    Mask = ModifyToType(Mask, WideMaskVT, true);
+  if (OpNo == 1) {
+    if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+        VT == MVT::nxv1i64) {
+      EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+      EVT WideMaskVT =
+          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                           WidenVT.getVectorMinNumElements(), true);
+      StVal = ModifyToType(StVal, WidenVT);
+      Mask = ModifyToType(Mask, WideMaskVT, true);
+    } else {
+      // Widen the value.
+      StVal = GetWidenedVector(StVal);
+
+      // The mask should be widened as well.
+      EVT WideVT = StVal.getValueType();
+      EVT WideMaskVT =
+          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                           WideVT.getVectorNumElements());
+      Mask = ModifyToType(Mask, WideMaskVT, true);
+    }
   } else {
     // Widen the mask.
     EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
     Mask = ModifyToType(Mask, WideMaskVT, true);
 
-    EVT ValueVT = StVal.getValueType();
-    EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
-                                  ValueVT.getVectorElementType(),
+    EVT WideVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
                                   WideMaskVT.getVectorNumElements());
     StVal = ModifyToType(StVal, WideVT);
   }
 
-  assert(Mask.getValueType().getVectorNumElements() ==
-         StVal.getValueType().getVectorNumElements() &&
+  assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
+                                : Mask.getValueType().getVectorNumElements()) ==
+             (VT.isScalableVector()
+                  ? StVal.getValueType().getVectorMinNumElements()
+                  : StVal.getValueType().getVectorNumElements()) &&
          "Mask and data vectors should have the same number of elements");
   return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
                             MST->getOffset(), Mask, MST->getMemoryVT(),
diff --git a/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
new file mode 100644
index 0000000000000..d7820273a6f78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s --check-prefix CHECK
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @store_i8(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i8> %val) #0 {
+; CHECK-LABEL: store_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i16(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i16> %val) #0 {
+; CHECK-LABEL: store_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i32(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i32> %val) #0 {
+; CHECK-LABEL: store_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i64(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i64> %val) #0 {
+; CHECK-LABEL: store_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @load_store_i32(<vscale x 1 x i1> %pred, ptr %x, ptr %y) #0 {
+; CHECK-LABEL: load_store_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x1]
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 2, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %load, ptr %y, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define <vscale x 1 x i8> @load_i8(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i8> undef)
+  ret <vscale x 1 x i8> %ret
+}
+
+define <vscale x 1 x i16> @load_i16(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i16> undef)
+  ret <vscale x 1 x i16> %ret
+}
+
+define <vscale x 1 x i32> @load_i32(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+  ret <vscale x 1 x i32> %ret
+}
+
+define <vscale x 1 x i64> @load_i64(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i64> undef)
+  ret <vscale x 1 x i64> %ret
+}
+
+declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
+declare <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i16>)
+declare <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i32>)
+declare <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64>, ptr, i32 immarg, <vscale x 1 x i1>)
+
+attributes #0 = { "target-features"="+sve" }

``````````



https://github.com/llvm/llvm-project/pull/99354

