[llvm] 7fb66d4 - [MemCpyOpt] Fix a variety of scalable-type crashes

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 8 03:31:36 PDT 2021


Author: Fraser Cormack
Date: 2021-09-08T11:21:36+01:00
New Revision: 7fb66d4035960b3c2797eae73d79b8478ff0348e

URL: https://github.com/llvm/llvm-project/commit/7fb66d4035960b3c2797eae73d79b8478ff0348e
DIFF: https://github.com/llvm/llvm-project/commit/7fb66d4035960b3c2797eae73d79b8478ff0348e.diff

LOG: [MemCpyOpt] Fix a variety of scalable-type crashes

This patch fixes a variety of crashes resulting from the `MemCpyOptPass`
casting `TypeSize` to a constant integer, whether implicitly or
explicitly.

Since `MemsetRanges` requires a constant size to work, all but one of
the fixes in this patch simply skip the various optimizations for
scalable types as cleanly as possible.

The optimization of `byval` parameters, however, has been updated to
work on scalable types in theory. In practice, this optimization is only
valid when the length of the `memcpy` is known to be greater than or
equal to the scalable type's size, which is currently never the case.
This could perhaps be improved in the future using the `vscale_range`
attribute.

Some implicit casts have been left as they were, since they are only
ever applied to aggregate types, which should never be scalably sized.

Reviewed By: nikic, tra

Differential Revision: https://reviews.llvm.org/D109329

Added: 
    llvm/test/Transforms/MemCpyOpt/vscale-crashes.ll

Modified: 
    llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
    llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
index 76dbec47fbb8b..3a4db13d670a8 100644
--- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
+++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
@@ -62,7 +62,7 @@ class MemCpyOptPass : public PassInfoMixin<MemCpyOptPass> {
   bool processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI);
   bool processMemMove(MemMoveInst *M);
   bool performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore,
-                            Value *cpyDst, Value *cpySrc, uint64_t cpyLen,
+                            Value *cpyDst, Value *cpySrc, TypeSize cpyLen,
                             Align cpyAlign, CallInst *C);
   bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep);
   bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet);

diff  --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 49070cf2e6b7a..67335a45fb58f 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -178,9 +178,9 @@ class MemsetRanges {
   }
 
   void addStore(int64_t OffsetFromFirst, StoreInst *SI) {
-    int64_t StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType());
-
-    addRange(OffsetFromFirst, StoreSize, SI->getPointerOperand(),
+    TypeSize StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType());
+    assert(!StoreSize.isScalable() && "Can't track scalable-typed stores");
+    addRange(OffsetFromFirst, StoreSize.getFixedSize(), SI->getPointerOperand(),
              SI->getAlign().value(), SI);
   }
 
@@ -363,6 +363,11 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
                                                  Value *ByteVal) {
   const DataLayout &DL = StartInst->getModule()->getDataLayout();
 
+  // We can't track scalable types
+  if (StoreInst *SI = dyn_cast<StoreInst>(StartInst))
+    if (DL.getTypeStoreSize(SI->getOperand(0)->getType()).isScalable())
+      return nullptr;
+
   // Okay, so we now have a single store that can be splatable.  Scan to find
   // all subsequent stores of the same value to offset from the same pointer.
   // Join these together into ranges, so we can decide whether contiguous blocks
@@ -416,6 +421,10 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
       if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
         break;
 
+      // We can't track ranges involving scalable types.
+      if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
+        break;
+
       // Check to see if this stored value is of the same byte-splattable value.
       Value *StoredByte = isBytewiseValue(StoredVal, DL);
       if (isa<UndefValue>(ByteVal) && StoredByte)
@@ -836,7 +845,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
 /// the call write its result directly into the destination of the memcpy.
 bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
                                          Instruction *cpyStore, Value *cpyDest,
-                                         Value *cpySrc, uint64_t cpyLen,
+                                         Value *cpySrc, TypeSize cpySize,
                                          Align cpyAlign, CallInst *C) {
   // The general transformation to keep in mind is
   //
@@ -852,6 +861,10 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
   // src only holds uninitialized values at the moment of the call, meaning that
   // the memcpy can be discarded rather than moved.
 
+  // We can't optimize scalable types.
+  if (cpySize.isScalable())
+    return false;
+
   // Lifetime marks shouldn't be operated on.
   if (Function *F = C->getCalledFunction())
     if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start)
@@ -870,13 +883,13 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
   uint64_t srcSize = DL.getTypeAllocSize(srcAlloca->getAllocatedType()) *
                      srcArraySize->getZExtValue();
 
-  if (cpyLen < srcSize)
+  if (cpySize < srcSize)
     return false;
 
   // Check that accessing the first srcSize bytes of dest will not cause a
   // trap.  Otherwise the transform is invalid since it might cause a trap
   // to occur earlier than it otherwise would.
-  if (!isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpyLen),
+  if (!isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpySize),
                                           DL, C, DT))
     return false;
 
@@ -1370,8 +1383,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
             // of conservatively taking the minimum?
             Align Alignment = std::min(M->getDestAlign().valueOrOne(),
                                        M->getSourceAlign().valueOrOne());
-            if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(),
-                                     CopySize->getZExtValue(), Alignment, C)) {
+            if (performCallSlotOptzn(
+                    M, M, M->getDest(), M->getSource(),
+                    TypeSize::getFixed(CopySize->getZExtValue()), Alignment,
+                    C)) {
               LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n"
                                 << "    call: " << *C << "\n"
                                 << "    memcpy: " << *M << "\n");
@@ -1435,7 +1450,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
   // Find out what feeds this byval argument.
   Value *ByValArg = CB.getArgOperand(ArgNo);
   Type *ByValTy = CB.getParamByValType(ArgNo);
-  uint64_t ByValSize = DL.getTypeAllocSize(ByValTy);
+  TypeSize ByValSize = DL.getTypeAllocSize(ByValTy);
   MemoryLocation Loc(ByValArg, LocationSize::precise(ByValSize));
   MemoryUseOrDef *CallAccess = MSSA->getMemoryAccess(&CB);
   if (!CallAccess)
@@ -1455,7 +1470,8 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
 
   // The length of the memcpy must be larger or equal to the size of the byval.
   ConstantInt *C1 = dyn_cast<ConstantInt>(MDep->getLength());
-  if (!C1 || C1->getValue().getZExtValue() < ByValSize)
+  if (!C1 || !TypeSize::isKnownGE(
+                 TypeSize::getFixed(C1->getValue().getZExtValue()), ByValSize))
     return false;
 
   // Get the alignment of the byval.  If the call doesn't specify the alignment,

diff  --git a/llvm/test/Transforms/MemCpyOpt/vscale-crashes.ll b/llvm/test/Transforms/MemCpyOpt/vscale-crashes.ll
new file mode 100644
index 0000000000000..6b81fb52c631a
--- /dev/null
+++ b/llvm/test/Transforms/MemCpyOpt/vscale-crashes.ll
@@ -0,0 +1,101 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -memcpyopt -S -verify-memoryssa | FileCheck %s
+
+; Check that a call featuring a scalable-vector byval argument fed by a memcpy
+; doesn't crash the compiler. It previously assumed the byval type's size could
+; be represented as a known constant amount.
+define void @byval_caller(i8 *%P) {
+; CHECK-LABEL: @byval_caller(
+; CHECK-NEXT:    [[A:%.*]] = alloca i8, align 1
+; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[A]], i8* align 4 [[P:%.*]], i64 8, i1 false)
+; CHECK-NEXT:    [[VA:%.*]] = bitcast i8* [[A]] to <vscale x 1 x i8>*
+; CHECK-NEXT:    call void @byval_callee(<vscale x 1 x i8>* byval(<vscale x 1 x i8>) align 1 [[VA]])
+; CHECK-NEXT:    ret void
+;
+  %a = alloca i8
+  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 %a, i8* align 4 %P, i64 8, i1 false)
+  %va = bitcast i8* %a to <vscale x 1 x i8>*
+  call void @byval_callee(<vscale x 1 x i8>* align 1 byval(<vscale x 1 x i8>) %va)
+  ret void
+}
+
+declare void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4, i8* align 4, i64, i1)
+declare void @byval_callee(<vscale x 1 x i8>* align 1 byval(<vscale x 1 x i8>))
+
+; Check that two scalable-vector stores (overlapping, with a constant offset)
+; do not crash the compiler when checked whether or not they can be merged into
+; a single memset. There was previously an assumption that the stored values'
+; sizes could be represented by a known constant amount.
+define void @merge_stores_both_scalable(<vscale x 1 x i8>* %ptr) {
+; CHECK-LABEL: @merge_stores_both_scalable(
+; CHECK-NEXT:    store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[PTRI8:%.*]] = bitcast <vscale x 1 x i8>* [[PTR]] to i8*
+; CHECK-NEXT:    [[PTR_NEXT:%.*]] = getelementptr i8, i8* [[PTRI8]], i64 1
+; CHECK-NEXT:    [[PTR_NEXT_2:%.*]] = bitcast i8* [[PTR_NEXT]] to <vscale x 1 x i8>*
+; CHECK-NEXT:    store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* [[PTR_NEXT_2]], align 1
+; CHECK-NEXT:    ret void
+;
+  store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* %ptr
+  %ptri8 = bitcast <vscale x 1 x i8>* %ptr to i8*
+  %ptr.next = getelementptr i8, i8* %ptri8, i64 1
+  %ptr.next.2 = bitcast i8* %ptr.next to <vscale x 1 x i8>*
+  store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* %ptr.next.2
+  ret void
+}
+
+; As above, but where the base is scalable but the subsequent store(s) are not.
+define void @merge_stores_first_scalable(<vscale x 1 x i8>* %ptr) {
+; CHECK-LABEL: @merge_stores_first_scalable(
+; CHECK-NEXT:    store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[PTRI8:%.*]] = bitcast <vscale x 1 x i8>* [[PTR]] to i8*
+; CHECK-NEXT:    [[PTR_NEXT:%.*]] = getelementptr i8, i8* [[PTRI8]], i64 1
+; CHECK-NEXT:    store i8 0, i8* [[PTR_NEXT]], align 1
+; CHECK-NEXT:    ret void
+;
+  store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* %ptr
+  %ptri8 = bitcast <vscale x 1 x i8>* %ptr to i8*
+  %ptr.next = getelementptr i8, i8* %ptri8, i64 1
+  store i8 zeroinitializer, i8* %ptr.next
+  ret void
+}
+
+; As above, but where the base is not scalable but the subsequent store(s) are.
+define void @merge_stores_second_scalable(i8* %ptr) {
+; CHECK-LABEL: @merge_stores_second_scalable(
+; CHECK-NEXT:    store i8 0, i8* [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[PTR_NEXT:%.*]] = getelementptr i8, i8* [[PTR]], i64 1
+; CHECK-NEXT:    [[PTR_NEXT_2:%.*]] = bitcast i8* [[PTR_NEXT]] to <vscale x 1 x i8>*
+; CHECK-NEXT:    store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* [[PTR_NEXT_2]], align 1
+; CHECK-NEXT:    ret void
+;
+  store i8 zeroinitializer, i8* %ptr
+  %ptr.next = getelementptr i8, i8* %ptr, i64 1
+  %ptr.next.2 = bitcast i8* %ptr.next to <vscale x 1 x i8>*
+  store <vscale x 1 x i8> zeroinitializer, <vscale x 1 x i8>* %ptr.next.2
+  ret void
+}
+
+; Check that the call-slot optimization doesn't crash when encountering scalable types.
+define void @callslotoptzn(<vscale x 4 x float> %val, <vscale x 4 x float>* %out) {
+; CHECK-LABEL: @callslotoptzn(
+; CHECK-NEXT:    [[ALLOC:%.*]] = alloca <vscale x 4 x float>, align 16
+; CHECK-NEXT:    [[IDX:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+; CHECK-NEXT:    [[BALLOC:%.*]] = getelementptr inbounds <vscale x 4 x float>, <vscale x 4 x float>* [[ALLOC]], i64 0, i64 0
+; CHECK-NEXT:    [[STRIDE:%.*]] = getelementptr inbounds float, float* [[BALLOC]], <vscale x 4 x i32> [[IDX]]
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4f32.nxv4p0f32(<vscale x 4 x float> [[VAL:%.*]], <vscale x 4 x float*> [[STRIDE]], i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    [[LI:%.*]] = load <vscale x 4 x float>, <vscale x 4 x float>* [[ALLOC]], align 4
+; CHECK-NEXT:    store <vscale x 4 x float> [[LI]], <vscale x 4 x float>* [[OUT:%.*]], align 4
+; CHECK-NEXT:    ret void
+;
+  %alloc = alloca <vscale x 4 x float>, align 16
+  %idx = tail call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+  %balloc = getelementptr inbounds <vscale x 4 x float>, <vscale x 4 x float>* %alloc, i64 0, i64 0
+  %stride = getelementptr inbounds float, float* %balloc, <vscale x 4 x i32> %idx
+  call void @llvm.masked.scatter.nxv4f32.nxv4p0f32(<vscale x 4 x float> %val, <vscale x 4 x float*> %stride, i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+  %li = load <vscale x 4 x float>, <vscale x 4 x float>* %alloc, align 4
+  store <vscale x 4 x float> %li, <vscale x 4 x float>* %out, align 4
+  ret void
+}
+
+declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+declare void @llvm.masked.scatter.nxv4f32.nxv4p0f32(<vscale x 4 x float> , <vscale x 4 x float*> , i32, <vscale x 4 x i1>)


        


More information about the llvm-commits mailing list