[llvm] [AArch64] Fix lowering error for masked load/store integer scalable ve… (PR #99354)
Dinar Temirbulatov via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 19 07:08:44 PDT 2024
https://github.com/dtemirbulatov updated https://github.com/llvm/llvm-project/pull/99354
>From ec8a0578e27368b24151e2f7efa7e63d2e2985e1 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Wed, 17 Jul 2024 16:02:41 +0000
Subject: [PATCH 1/4] [AArch64] Fix lowering error for masked load/store of
integer scalable vector types with one element.
The proposed change fixes an error while lowering masked loads/stores of the
integer scalable vector types nxv1i8, nxv1i16, nxv1i32 and nxv1i64.
---
.../SelectionDAG/LegalizeVectorTypes.cpp | 62 +++++---
.../CodeGen/AArch64/sve-nxv1-load-store.ll | 142 ++++++++++++++++++
2 files changed, 184 insertions(+), 20 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1a575abbc16f4..bb210fcd7bdab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5749,18 +5749,27 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
- EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+ EVT VT = N->getValueType(0);
+ EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDValue Mask = N->getMask();
EVT MaskVT = Mask.getValueType();
SDValue PassThru = GetWidenedVector(N->getPassThru());
ISD::LoadExtType ExtType = N->getExtensionType();
SDLoc dl(N);
- // The mask should be widened as well
- EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
- MaskVT.getVectorElementType(),
- WidenVT.getVectorNumElements());
- Mask = ModifyToType(Mask, WideMaskVT, true);
+ if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+ VT == MVT::nxv1i64) {
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorMinNumElements(), true);
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ } else {
+ // The mask should be widened as well
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorNumElements());
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ }
SDValue Res = DAG.getMaskedLoad(
WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6914,30 +6923,43 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
SDValue StVal = MST->getValue();
SDLoc dl(N);
- if (OpNo == 1) {
- // Widen the value.
- StVal = GetWidenedVector(StVal);
+ EVT VT = StVal.getValueType();
- // The mask should be widened as well.
- EVT WideVT = StVal.getValueType();
- EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
- MaskVT.getVectorElementType(),
- WideVT.getVectorNumElements());
- Mask = ModifyToType(Mask, WideMaskVT, true);
+ if (OpNo == 1) {
+ if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+ VT == MVT::nxv1i64) {
+ EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorMinNumElements(), true);
+ StVal = ModifyToType(StVal, WidenVT);
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ } else {
+ // Widen the value.
+ StVal = GetWidenedVector(StVal);
+
+ // The mask should be widened as well.
+ EVT WideVT = StVal.getValueType();
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WideVT.getVectorNumElements());
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ }
} else {
// Widen the mask.
EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
Mask = ModifyToType(Mask, WideMaskVT, true);
- EVT ValueVT = StVal.getValueType();
- EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
- ValueVT.getVectorElementType(),
+ EVT WideVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
WideMaskVT.getVectorNumElements());
StVal = ModifyToType(StVal, WideVT);
}
- assert(Mask.getValueType().getVectorNumElements() ==
- StVal.getValueType().getVectorNumElements() &&
+ assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
+ : Mask.getValueType().getVectorNumElements()) ==
+ (VT.isScalableVector()
+ ? StVal.getValueType().getVectorMinNumElements()
+ : StVal.getValueType().getVectorNumElements()) &&
"Mask and data vectors should have the same number of elements");
return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
MST->getOffset(), Mask, MST->getMemoryVT(),
diff --git a/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
new file mode 100644
index 0000000000000..d7820273a6f78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s --check-prefix CHECK
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @store_i8(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i8> %val) #0 {
+; CHECK-LABEL: store_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @store_i16(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i16> %val) #0 {
+; CHECK-LABEL: store_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @store_i32(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i32> %val) #0 {
+; CHECK-LABEL: store_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @store_i64(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i64> %val) #0 {
+; CHECK-LABEL: store_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @load_store_i32(<vscale x 1 x i1> %pred, ptr %x, ptr %y) #0 {
+; CHECK-LABEL: load_store_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT: st1w { z0.s }, p0, [x1]
+; CHECK-NEXT: ret
+ %load = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 2, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+ call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %load, ptr %y, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define <vscale x 1 x i8> @load_i8(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i8> undef)
+ ret <vscale x 1 x i8> %ret
+}
+
+define <vscale x 1 x i16> @load_i16(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i16> undef)
+ ret <vscale x 1 x i16> %ret
+}
+
+define <vscale x 1 x i32> @load_i32(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+ ret <vscale x 1 x i32> %ret
+}
+
+define <vscale x 1 x i64> @load_i64(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i64> undef)
+ ret <vscale x 1 x i64> %ret
+}
+
+declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
+declare <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i16>)
+declare <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i32>)
+declare <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64>, ptr, i32 immarg, <vscale x 1 x i1>)
+
+attributes #0 = { "target-features"="+sve" }
>From d453816edc09d0f0274e1e88e3fb4edaf241bf14 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Fri, 19 Jul 2024 13:37:12 +0000
Subject: [PATCH 2/4] Resolved comments.
---
.gitreview | 3 ++
.../SelectionDAG/LegalizeVectorTypes.cpp | 51 +++++++------------
2 files changed, 21 insertions(+), 33 deletions(-)
create mode 100644 .gitreview
diff --git a/.gitreview b/.gitreview
new file mode 100644
index 0000000000000..9b4839c770879
--- /dev/null
+++ b/.gitreview
@@ -0,0 +1,3 @@
+[gerrit]
+defaultremote=origin
+defaultbranch=main
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bb210fcd7bdab..d75afee9b652e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5748,28 +5748,18 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
}
SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
-
- EVT VT = N->getValueType(0);
- EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+ EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDValue Mask = N->getMask();
EVT MaskVT = Mask.getValueType();
SDValue PassThru = GetWidenedVector(N->getPassThru());
ISD::LoadExtType ExtType = N->getExtensionType();
SDLoc dl(N);
- if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
- VT == MVT::nxv1i64) {
- EVT WideMaskVT =
- EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
- WidenVT.getVectorMinNumElements(), true);
- Mask = ModifyToType(Mask, WideMaskVT, true);
- } else {
- // The mask should be widened as well
- EVT WideMaskVT =
- EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
- WidenVT.getVectorNumElements());
- Mask = ModifyToType(Mask, WideMaskVT, true);
- }
+ // The mask should be widened as well
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorElementCount());
+ Mask = ModifyToType(Mask, WideMaskVT, true);
SDValue Res = DAG.getMaskedLoad(
WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6921,30 +6911,25 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
SDValue Mask = MST->getMask();
EVT MaskVT = Mask.getValueType();
SDValue StVal = MST->getValue();
- SDLoc dl(N);
-
EVT VT = StVal.getValueType();
+ SDLoc dl(N);
if (OpNo == 1) {
- if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
- VT == MVT::nxv1i64) {
- EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
- EVT WideMaskVT =
- EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
- WidenVT.getVectorMinNumElements(), true);
- StVal = ModifyToType(StVal, WidenVT);
- Mask = ModifyToType(Mask, WideMaskVT, true);
+ EVT WideVT;
+ if (VT.isScalableVector() && VT.getVectorMinNumElements() == 1 &&
+ VT.isInteger() && VT.getVectorElementType().isByteSized()) {
+ WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+ StVal = ModifyToType(StVal, WideVT);
} else {
// Widen the value.
StVal = GetWidenedVector(StVal);
-
- // The mask should be widened as well.
- EVT WideVT = StVal.getValueType();
- EVT WideMaskVT =
- EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
- WideVT.getVectorNumElements());
- Mask = ModifyToType(Mask, WideMaskVT, true);
+ // The mask should be widened as well
+ WideVT = StVal.getValueType();
}
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WideVT.getVectorElementCount());
+ Mask = ModifyToType(Mask, WideMaskVT, true);
} else {
// Widen the mask.
EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
>From 3f0cabaae333aeb14adeed3b968d50d176f73cc9 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Fri, 19 Jul 2024 14:02:38 +0000
Subject: [PATCH 3/4] Removed extra file.
---
.gitreview | 3 ---
1 file changed, 3 deletions(-)
delete mode 100644 .gitreview
diff --git a/.gitreview b/.gitreview
deleted file mode 100644
index 9b4839c770879..0000000000000
--- a/.gitreview
+++ /dev/null
@@ -1,3 +0,0 @@
-[gerrit]
-defaultremote=origin
-defaultbranch=main
>From 4956512745a8ce4b8fe3309aa6743e8824e32905 Mon Sep 17 00:00:00 2001
From: Dinar Temirbulatov <Dinar.Temirbulatov at arm.com>
Date: Fri, 19 Jul 2024 15:08:35 +0100
Subject: [PATCH 4/4] Update
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Done.
Co-authored-by: Sander de Smalen <sander.desmalen at arm.com>
---
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index d75afee9b652e..b357d677ca368 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6940,11 +6940,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
StVal = ModifyToType(StVal, WideVT);
}
- assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
- : Mask.getValueType().getVectorNumElements()) ==
- (VT.isScalableVector()
- ? StVal.getValueType().getVectorMinNumElements()
- : StVal.getValueType().getVectorNumElements()) &&
+ assert(Mask.getValueType().getVectorMinNumElements() ==
+ StVal.getValueType().getVectorMinNumElements() &&
"Mask and data vectors should have the same number of elements");
return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
MST->getOffset(), Mask, MST->getMemoryVT(),
More information about the llvm-commits
mailing list