[llvm] [AArch64] Fix lowering error for masked load/store integer scalable ve… (PR #99354)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 17 09:50:00 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-selectiondag
Author: Dinar Temirbulatov (dtemirbulatov)
<details>
<summary>Changes</summary>
…ctors with one-element types.
The proposed change fixes an error hit while lowering masked loads and stores of the integer scalable vector types nxv1i8, nxv1i16, nxv1i32, and nxv1i64.
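For context, a minimal reproducer in the spirit of the added test (the function name is illustrative; requires the +sve target feature). Without this patch, type legalization fails on the nxv1 value type; with it, the store lowers to a predicated st1w as checked in the new test:

```llvm
; Illustrative reproducer based on the added sve-nxv1-load-store.ll test.
define void @store_nxv1i32(<vscale x 1 x i1> %pred, ptr %x, <vscale x 1 x i32> %val) #0 {
  call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
  ret void
}

declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)

attributes #0 = { "target-features"="+sve" }
```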
---
Full diff: https://github.com/llvm/llvm-project/pull/99354.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+42-20)
- (added) llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll (+142)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1a575abbc16f4..bb210fcd7bdab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5749,18 +5749,27 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
- EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+ EVT VT = N->getValueType(0);
+ EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDValue Mask = N->getMask();
EVT MaskVT = Mask.getValueType();
SDValue PassThru = GetWidenedVector(N->getPassThru());
ISD::LoadExtType ExtType = N->getExtensionType();
SDLoc dl(N);
- // The mask should be widened as well
- EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
- MaskVT.getVectorElementType(),
- WidenVT.getVectorNumElements());
- Mask = ModifyToType(Mask, WideMaskVT, true);
+ if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+ VT == MVT::nxv1i64) {
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorMinNumElements(), true);
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ } else {
+ // The mask should be widened as well
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorNumElements());
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ }
SDValue Res = DAG.getMaskedLoad(
WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
@@ -6914,30 +6923,43 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
SDValue StVal = MST->getValue();
SDLoc dl(N);
- if (OpNo == 1) {
- // Widen the value.
- StVal = GetWidenedVector(StVal);
+ EVT VT = StVal.getValueType();
- // The mask should be widened as well.
- EVT WideVT = StVal.getValueType();
- EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
- MaskVT.getVectorElementType(),
- WideVT.getVectorNumElements());
- Mask = ModifyToType(Mask, WideMaskVT, true);
+ if (OpNo == 1) {
+ if (VT == MVT::nxv1i8 || VT == MVT::nxv1i16 || VT == MVT::nxv1i32 ||
+ VT == MVT::nxv1i64) {
+ EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WidenVT.getVectorMinNumElements(), true);
+ StVal = ModifyToType(StVal, WidenVT);
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ } else {
+ // Widen the value.
+ StVal = GetWidenedVector(StVal);
+
+ // The mask should be widened as well.
+ EVT WideVT = StVal.getValueType();
+ EVT WideMaskVT =
+ EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+ WideVT.getVectorNumElements());
+ Mask = ModifyToType(Mask, WideMaskVT, true);
+ }
} else {
// Widen the mask.
EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
Mask = ModifyToType(Mask, WideMaskVT, true);
- EVT ValueVT = StVal.getValueType();
- EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
- ValueVT.getVectorElementType(),
+ EVT WideVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
WideMaskVT.getVectorNumElements());
StVal = ModifyToType(StVal, WideVT);
}
- assert(Mask.getValueType().getVectorNumElements() ==
- StVal.getValueType().getVectorNumElements() &&
+ assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
+ : Mask.getValueType().getVectorNumElements()) ==
+ (VT.isScalableVector()
+ ? StVal.getValueType().getVectorMinNumElements()
+ : StVal.getValueType().getVectorNumElements()) &&
"Mask and data vectors should have the same number of elements");
return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
MST->getOffset(), Mask, MST->getMemoryVT(),
diff --git a/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
new file mode 100644
index 0000000000000..d7820273a6f78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-nxv1-load-store.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s --check-prefix CHECK
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @store_i8(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i8> %val) #0 {
+; CHECK-LABEL: store_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @store_i16(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i16> %val) #0 {
+; CHECK-LABEL: store_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @store_i32(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i32> %val) #0 {
+; CHECK-LABEL: store_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @store_i64(<vscale x 1 x i1> %pred, ptr %x, i64 %base, <vscale x 1 x i64> %val) #0 {
+; CHECK-LABEL: store_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ call void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64> %val, ptr %x, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define void @load_store_i32(<vscale x 1 x i1> %pred, ptr %x, ptr %y) #0 {
+; CHECK-LABEL: load_store_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT: st1w { z0.s }, p0, [x1]
+; CHECK-NEXT: ret
+ %load = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 2, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+ call void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32> %load, ptr %y, i32 1, <vscale x 1 x i1> %pred)
+ ret void
+}
+
+define <vscale x 1 x i8> @load_i8(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: uzp1 p1.h, p1.h, p1.h
+; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i8> undef)
+ ret <vscale x 1 x i8> %ret
+}
+
+define <vscale x 1 x i16> @load_i16(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: uzp1 p1.s, p1.s, p1.s
+; CHECK-NEXT: uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i16> undef)
+ ret <vscale x 1 x i16> %ret
+}
+
+define <vscale x 1 x i32> @load_i32(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: uzp1 p1.d, p1.d, p1.d
+; CHECK-NEXT: uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i32> undef)
+ ret <vscale x 1 x i32> %ret
+}
+
+define <vscale x 1 x i64> @load_i64(<vscale x 1 x i1> %pred, ptr %x) #0 {
+; CHECK-LABEL: load_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: pfalse p1.b
+; CHECK-NEXT: uzp1 p0.d, p0.d, p1.d
+; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT: ret
+ %ret = call <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr %x, i32 1, <vscale x 1 x i1> %pred, <vscale x 1 x i64> undef)
+ ret <vscale x 1 x i64> %ret
+}
+
+declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
+declare <vscale x 1 x i16> @llvm.masked.load.nxv1i16(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i16>)
+declare <vscale x 1 x i32> @llvm.masked.load.nxv1i32(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i32>)
+declare <vscale x 1 x i64> @llvm.masked.load.nxv1i64(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare void @llvm.masked.store.nxv1i8.p0nxv1i8(<vscale x 1 x i8>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i16.p0nxv1i16(<vscale x 1 x i16>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i32.p0nxv1i32(<vscale x 1 x i32>, ptr, i32 immarg, <vscale x 1 x i1>)
+declare void @llvm.masked.store.nxv1i64.p0nxv1i64(<vscale x 1 x i64>, ptr, i32 immarg, <vscale x 1 x i1>)
+
+attributes #0 = { "target-features"="+sve" }
``````````
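The recurring pattern in both hunks is to build the widened mask type as a scalable vector whenever the loaded/stored value type is one of the nxv1 integer types. A condensed sketch of that pattern, not a standalone program (`Ctx` stands in for `*DAG.getContext()`):

```cpp
// Condensed sketch of the mask-widening pattern used in the patch above.
// For the nxv1 integer types the widened mask must itself be scalable, so the
// element count is taken from getVectorMinNumElements() and the IsScalable
// flag of EVT::getVectorVT is set; other types keep the existing fixed-width
// path based on getVectorNumElements().
bool IsNxv1Int = VT == MVT::nxv1i8 || VT == MVT::nxv1i16 ||
                 VT == MVT::nxv1i32 || VT == MVT::nxv1i64;
EVT WideMaskVT =
    IsNxv1Int ? EVT::getVectorVT(Ctx, MaskVT.getVectorElementType(),
                                 WidenVT.getVectorMinNumElements(),
                                 /*IsScalable=*/true)
              : EVT::getVectorVT(Ctx, MaskVT.getVectorElementType(),
                                 WidenVT.getVectorNumElements());
Mask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
```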
</details>
https://github.com/llvm/llvm-project/pull/99354