[llvm] [AArch64] Fix lowering error for masked load/store integer scalable ve… (PR #99354)

Dinar Temirbulatov via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 19 07:07:21 PDT 2024


https://github.com/dtemirbulatov updated https://github.com/llvm/llvm-project/pull/99354

>From ec8a0578e27368b24151e2f7efa7e63d2e2985e1 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Wed, 17 Jul 2024 16:02:41 +0000
Subject: [PATCH 1/3] [AArch64] Fix lowering error for masked load/store integer
 scalable vectors with one-element types.

The proposed change fixes an error while lowering masked loads/stores of the
scalable integer vector types nxv1i8, nxv1i16, nxv1i32, and nxv1i64.
---
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  62 +++++---
 .../CodeGen/AArch64/sve-nxv1-load-store.ll    | 142 ++++++++++++++++++
 2 files changed, 184 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1a575abbc16f4..bb210fcd7bdab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5749,18 +5749,27 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
 
 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
 
-  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+  EVT VT = N->getValueType(0);
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue PassThru = GetWidenedVector(N->getPassThru());
   ISD::LoadExtType ExtType = N->getExtensionType();
   SDLoc dl(N);
 
-  // The mask should be widened as well
-  EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                    MaskVT.getVectorElementType(),
-                                    WidenVT.getVectorNumElements());
-  Mask = ModifyToType(Mask, WideMaskVT, true);
+  if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+      VT == MVT::nxv1i64) {
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WidenVT.getVectorMinNumElements(), true);
+    Mask = ModifyToType(Mask, WideMaskVT, true);
+  } else {
+    // The mask should be widened as well
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WidenVT.getVectorNumElements());
+    Mask = ModifyToType(Mask, WideMaskVT, true);
+  }
 
   SDValue Res = DAG.getMaskedLoad(
       WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6914,30 +6923,43 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
   SDValue StVal = MST->getValue();
   SDLoc dl(N);
 
-  if (OpNo == 1) {
-    // Widen the value.
-    StVal = GetWidenedVector(StVal);
+  EVT VT = StVal.getValueType();
 
-    // The mask should be widened as well.
-    EVT WideVT = StVal.getValueType();
-    EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                      MaskVT.getVectorElementType(),
-                                      WideVT.getVectorNumElements());
-    Mask = ModifyToType(Mask, WideMaskVT, true);
+  if (OpNo == 1) {
+    if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+        VT == MVT::nxv1i64) {
+      EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+      EVT WideMaskVT =
+          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                           WidenVT.getVectorMinNumElements(), true);
+      StVal = ModifyToType(StVal, WidenVT);
+      Mask = ModifyToType(Mask, WideMaskVT, true);
+    } else {
+      // Widen the value.
+      StVal = GetWidenedVector(StVal);
+
+      // The mask should be widened as well.
+      EVT WideVT = StVal.getValueType();
+      EVT WideMaskVT =
+          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                           WideVT.getVectorNumElements());
+      Mask = ModifyToType(Mask, WideMaskVT, true);
+    }
   } else {
     // Widen the mask.
     EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
     Mask = ModifyToType(Mask, WideMaskVT, true);
 
-    EVT ValueVT = StVal.getValueType();
-    EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
-                                  ValueVT.getVectorElementType(),
+    EVT WideVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
                                   WideMaskVT.getVectorNumElements());
     StVal = ModifyToType(StVal, WideVT);
   }
 
-  assert(Mask.getValueType().getVectorNumElements() ==
-         StVal.getValueType().getVectorNumElements() &&
+  assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
+                                : Mask.getValueType().getVectorNumElements()) ==
+             (VT.isScalableVector()
+                  ? StVal.getValueType().getVectorMinNumElements()
+                  : StVal.getValueType().getVectorNumElements()) &&
          "Mask and data vectors should have the same number of elements");
   return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
                             MST->getOffset(), Mask, MST->getMemoryVT(),
diff --git a/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
new file mode 100644
index 0000000000000..d7820273a6f78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s --check-prefix CHECK
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @store_i8(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i8> %val) #0 {
+; CHECK-LABEL: store_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i16(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i16> %val) #0 {
+; CHECK-LABEL: store_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i32(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i32> %val) #0 {
+; CHECK-LABEL: store_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @store_i64(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i64> %val) #0 {
+; CHECK-LABEL: store_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define void @load_store_i32(<vscale x 1 x i1> %pred, ptr %x, ptr %y) #0 {
+; CHECK-LABEL: load_store_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x1]
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 2, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %load, ptr %y, i32 1, <vscale x 1 x i1> %pred)
+  ret void
+}
+
+define <vscale x 1 x i8> @load_i8(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i8> undef)
+  ret <vscale x 1 x i8> %ret
+}
+
+define <vscale x 1 x i16> @load_i16(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i16> undef)
+  ret <vscale x 1 x i16> %ret
+}
+
+define <vscale x 1 x i32> @load_i32(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+  ret <vscale x 1 x i32> %ret
+}
+
+define <vscale x 1 x i64> @load_i64(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %ret = call <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i64> undef)
+  ret <vscale x 1 x i64> %ret
+}
+
+declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
+declare <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i16>)
+declare <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i32>)
+declare <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64>, ptr, i32 immarg, <vscale x 1 x i1>)
+
+attributes #0 = { "target-features"="+sve" }

>From d453816edc09d0f0274e1e88e3fb4edaf241bf14 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Fri, 19 Jul 2024 13:37:12 +0000
Subject: [PATCH 2/3] Resolved comments.

---
 .gitreview                                    |  3 ++
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 51 +++++++------------
 2 files changed, 21 insertions(+), 33 deletions(-)
 create mode 100644 .gitreview

diff --git a/.gitreview b/.gitreview
new file mode 100644
index 0000000000000..9b4839c770879
--- /dev/null
+++ b/.gitreview
@@ -0,0 +1,3 @@
+[gerrit]
+defaultremote=origin
+defaultbranch=main
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bb210fcd7bdab..d75afee9b652e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5748,28 +5748,18 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
-
-  EVT VT = N->getValueType(0);
-  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue PassThru = GetWidenedVector(N->getPassThru());
   ISD::LoadExtType ExtType = N->getExtensionType();
   SDLoc dl(N);
 
-  if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
-      VT == MVT::nxv1i64) {
-    EVT WideMaskVT =
-        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
-                         WidenVT.getVectorMinNumElements(), true);
-    Mask = ModifyToType(Mask, WideMaskVT, true);
-  } else {
-    // The mask should be widened as well
-    EVT WideMaskVT =
-        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
-                         WidenVT.getVectorNumElements());
-    Mask = ModifyToType(Mask, WideMaskVT, true);
-  }
+  // The mask should be widened as well
+  EVT WideMaskVT =
+      EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                       WidenVT.getVectorElementCount());
+  Mask = ModifyToType(Mask, WideMaskVT, true);
 
   SDValue Res = DAG.getMaskedLoad(
       WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6921,30 +6911,25 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
   SDValue Mask = MST->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue StVal = MST->getValue();
-  SDLoc dl(N);
-
   EVT VT = StVal.getValueType();
+  SDLoc dl(N);
 
   if (OpNo == 1) {
-    if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
-        VT == MVT::nxv1i64) {
-      EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
-      EVT WideMaskVT =
-          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
-                           WidenVT.getVectorMinNumElements(), true);
-      StVal = ModifyToType(StVal, WidenVT);
-      Mask = ModifyToType(Mask, WideMaskVT, true);
+    EVT WideVT;
+    if (VT.isScalableVector() && VT.getVectorMinNumElements() == 1 &&
+        VT.isInteger() && VT.getVectorElementType().isByteSized()) {
+      WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+      StVal = ModifyToType(StVal, WideVT);
     } else {
       // Widen the value.
       StVal = GetWidenedVector(StVal);
-
-      // The mask should be widened as well.
-      EVT WideVT = StVal.getValueType();
-      EVT WideMaskVT =
-          EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
-                           WideVT.getVectorNumElements());
-      Mask = ModifyToType(Mask, WideMaskVT, true);
+      // The mask should be widened as well
+      WideVT = StVal.getValueType();
     }
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WideVT.getVectorElementCount());
+    Mask = ModifyToType(Mask, WideMaskVT, true);
   } else {
     // Widen the mask.
     EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);

>From 3f0cabaae333aeb14adeed3b968d50d176f73cc9 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Fri, 19 Jul 2024 14:02:38 +0000
Subject: [PATCH 3/3] Removed extra file.

---
 .gitreview | 3 ---
 1 file changed, 3 deletions(-)
 delete mode 100644 .gitreview

diff --git a/.gitreview b/.gitreview
deleted file mode 100644
index 9b4839c770879..0000000000000
--- a/.gitreview
+++ /dev/null
@@ -1,3 +0,0 @@
-[gerrit]
-defaultremote=origin
-defaultbranch=main



More information about the llvm-commits mailing list