[llvm] [AArch64] Fix lowering error for masked load/store integer scalable ve… (PR #99354)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 19 06:54:21 PDT 2024


================
@@ -6912,32 +6911,40 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
   SDValue Mask = MST->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue StVal = MST->getValue();
+  EVT VT = StVal.getValueType();
   SDLoc dl(N);
 
   if (OpNo == 1) {
-    // Widen the value.
-    StVal = GetWidenedVector(StVal);
-
-    // The mask should be widened as well.
-    EVT WideVT = StVal.getValueType();
-    EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                      MaskVT.getVectorElementType(),
-                                      WideVT.getVectorNumElements());
+    EVT WideVT;
+    if (VT.isScalableVector() && VT.getVectorMinNumElements() == 1 &&
+        VT.isInteger() && VT.getVectorElementType().isByteSized()) {
+      WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+      StVal = ModifyToType(StVal, WideVT);
+    } else {
+      // Widen the value.
+      StVal = GetWidenedVector(StVal);
+      // The mask should be widened as well
+      WideVT = StVal.getValueType();
+    }
+    EVT WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WideVT.getVectorElementCount());
     Mask = ModifyToType(Mask, WideMaskVT, true);
   } else {
     // Widen the mask.
     EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
     Mask = ModifyToType(Mask, WideMaskVT, true);
 
-    EVT ValueVT = StVal.getValueType();
-    EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
-                                  ValueVT.getVectorElementType(),
+    EVT WideVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
                                   WideMaskVT.getVectorNumElements());
     StVal = ModifyToType(StVal, WideVT);
   }
 
-  assert(Mask.getValueType().getVectorNumElements() ==
-         StVal.getValueType().getVectorNumElements() &&
+  assert((VT.isScalableVector() ? Mask.getValueType().getVectorMinNumElements()
+                                : Mask.getValueType().getVectorNumElements()) ==
+             (VT.isScalableVector()
+                  ? StVal.getValueType().getVectorMinNumElements()
+                  : StVal.getValueType().getVectorNumElements()) &&
----------------
sdesmalen-arm wrote:

```suggestion
  assert(Mask.getValueType().getVectorMinNumElements() ==
          StVal.getValueType().getVectorMinNumElements() &&
```

https://github.com/llvm/llvm-project/pull/99354


More information about the llvm-commits mailing list