[llvm] 368cb42 - [ASAN] Support memory checks on scalable vector typed masked load and store

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 10 16:20:32 PST 2023


Author: Philip Reames
Date: 2023-03-10T16:20:25-08:00
New Revision: 368cb421c396edcc4fa2a708f8ebb0cc56925424

URL: https://github.com/llvm/llvm-project/commit/368cb421c396edcc4fa2a708f8ebb0cc56925424
DIFF: https://github.com/llvm/llvm-project/commit/368cb421c396edcc4fa2a708f8ebb0cc56925424.diff

LOG: [ASAN] Support memory checks on scalable vector typed masked load and store

This takes the approach of using the loop-based formation for scalable vectors only. We could potentially use the loop form for fixed vectors as well, but we'd lose the unroll-and-specialize-on-constant-vector logic which is already present. I don't have a strong opinion on whether the existing logic is worthwhile; I kept it mostly to minimize test churn.

Worth noting is that there is a better lowering available. The plain vector lowering appears to check only the first and last byte. By analogy, we should be able to check only the first active and last active byte in the masked op. This is a more invasive change to asan, and I decided simply supporting scalable vectors at all was a better starting place.

Differential Revision: https://reviews.llvm.org/D145198

Added: 
    

Modified: 
    llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
    llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index 2f9ef7b688a5c..0b07c10ae8904 100644
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -1439,6 +1439,37 @@ static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I,
                                          IsWrite, nullptr, UseCalls, Exp);
 }
 
+static void SplitBlockAndInsertSimpleForLoop(Value *End,
+                                             Instruction *SplitBefore,
+                                             Instruction *&BodyIP,
+                                             Value *&Index) {
+  BasicBlock *LoopPred = SplitBefore->getParent();
+  BasicBlock *LoopBody = SplitBlock(SplitBefore->getParent(), SplitBefore);
+  BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore);
+
+  auto *Ty = End->getType();
+  auto &DL = SplitBefore->getModule()->getDataLayout();
+  const unsigned Bitwidth = DL.getTypeSizeInBits(Ty);
+
+  IRBuilder<> Builder(LoopBody->getTerminator());
+  auto *IV = Builder.CreatePHI(Ty, 2, "iv");
+  auto *IVNext =
+    Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next",
+                      /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
+  auto *IVCheck = Builder.CreateICmpEQ(IVNext, End,
+                                       IV->getName() + ".check");
+  Builder.CreateCondBr(IVCheck, LoopExit, LoopBody);
+  LoopBody->getTerminator()->eraseFromParent();
+
+  // Populate the IV PHI.
+  IV->addIncoming(ConstantInt::get(Ty, 0), LoopPred);
+  IV->addIncoming(IVNext, LoopBody);
+
+  BodyIP = LoopBody->getFirstNonPHI();
+  Index = IV;
+}
+
+
 static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass,
                                         const DataLayout &DL, Type *IntptrTy,
                                         Value *Mask, Instruction *I,
@@ -1446,36 +1477,65 @@ static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass,
                                         unsigned Granularity, Type *OpType,
                                         bool IsWrite, Value *SizeArgument,
                                         bool UseCalls, uint32_t Exp) {
-  auto *VTy = cast<FixedVectorType>(OpType);
-  uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
-  unsigned Num = VTy->getNumElements();
+  auto *VTy = cast<VectorType>(OpType);
+
+  TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
   auto Zero = ConstantInt::get(IntptrTy, 0);
-  for (unsigned Idx = 0; Idx < Num; ++Idx) {
-    Value *InstrumentedAddress = nullptr;
-    Instruction *InsertBefore = I;
-    if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
-      // dyn_cast as we might get UndefValue
-      if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
-        if (Masked->isZero())
-          // Mask is constant false, so no instrumentation needed.
-          continue;
-        // If we have a true or undef value, fall through to doInstrumentAddress
-        // with InsertBefore == I
+
+  // For fixed length vectors, it's legal to fallthrough into the generic loop
+  // lowering below, but we chose to unroll and specialize instead.  We might want
+  // to revisit this heuristic decision.
+  if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
+    unsigned Num = FVTy->getNumElements();
+    for (unsigned Idx = 0; Idx < Num; ++Idx) {
+      Value *InstrumentedAddress = nullptr;
+      Instruction *InsertBefore = I;
+      if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
+        // dyn_cast as we might get UndefValue
+        if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
+          if (Masked->isZero())
+            // Mask is constant false, so no instrumentation needed.
+            continue;
+          // If we have a true or undef value, fall through to doInstrumentAddress
+          // with InsertBefore == I
+        }
+      } else {
+        IRBuilder<> IRB(I);
+        Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
+        Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
+        InsertBefore = ThenTerm;
       }
-    } else {
-      IRBuilder<> IRB(I);
-      Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
-      Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
-      InsertBefore = ThenTerm;
-    }
 
-    IRBuilder<> IRB(InsertBefore);
-    InstrumentedAddress =
+      IRBuilder<> IRB(InsertBefore);
+      InstrumentedAddress =
         IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)});
-    doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
-                        Granularity, TypeSize::Fixed(ElemTypeSize), IsWrite,
-                        SizeArgument, UseCalls, Exp);
+      doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
+                          Granularity, ElemTypeSize, IsWrite,
+                          SizeArgument, UseCalls, Exp);
+    }
+    return;
   }
+
+
+  IRBuilder<> IRB(I);
+  Constant *MinNumElem =
+    ConstantInt::get(IntptrTy, VTy->getElementCount().getKnownMinValue());
+  assert(isa<ScalableVectorType>(VTy) && "generalize if reused for fixed length");
+  Value *NumElements = IRB.CreateVScale(MinNumElem);
+
+  Instruction *BodyIP;
+  Value *Index;
+  SplitBlockAndInsertSimpleForLoop(NumElements, I, BodyIP, Index);
+
+  IRB.SetInsertPoint(BodyIP);
+  Value *MaskElem = IRB.CreateExtractElement(Mask, Index);
+  Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, BodyIP, false);
+  IRB.SetInsertPoint(ThenTerm);
+
+  Value *InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index});
+  doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(), InstrumentedAddress, Alignment,
+                      Granularity, ElemTypeSize, IsWrite, SizeArgument,
+                      UseCalls, Exp);
 }
 
 void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,

diff  --git a/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll b/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll
index a6cde287a8120..45a7977865652 100644
--- a/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll
+++ b/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll
@@ -308,3 +308,68 @@ define <4 x float> @load.v4f32.1001.after.full.load(ptr %p, <4 x float> %arg) sa
   %res2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 4, <4 x i1> <i1 false, i1 false, i1 false, i1 true>, <4 x float> %arg)
   ret <4 x float> %res2
 }
+
+;; Scalable vector tests
+;; ---------------------------
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float>, ptr, i32, <vscale x 4 x i1>)
+
+define <vscale x 4 x float> @scalable.load.nxv4f32(ptr %p, <vscale x 4 x i1> %mask) sanitize_address {
+; CHECK-LABEL: @scalable.load.nxv4f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label [[DOTSPLIT:%.*]]
+; CHECK:       .split:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+; CHECK-NEXT:    call void @__asan_load4(i64 [[TMP6]])
+; CHECK-NEXT:    br label [[TMP7]]
+; CHECK:       7:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
+; CHECK:       .split.split:
+; CHECK-NEXT:    [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> undef)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[RES]]
+;
+; DISABLED-LABEL: @scalable.load.nxv4f32(
+; DISABLED-NEXT:    [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> undef)
+; DISABLED-NEXT:    ret <vscale x 4 x float> [[RES]]
+;
+  %res = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %res
+}
+
+define void @scalable.store.nxv4f32(ptr %p, <vscale x 4 x float> %arg, <vscale x 4 x i1> %mask) sanitize_address {
+; CHECK-LABEL: @scalable.store.nxv4f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label [[DOTSPLIT:%.*]]
+; CHECK:       .split:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+; CHECK-NEXT:    call void @__asan_store4(i64 [[TMP6]])
+; CHECK-NEXT:    br label [[TMP7]]
+; CHECK:       7:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
+; CHECK:       .split.split:
+; CHECK-NEXT:    tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT:    ret void
+;
+; DISABLED-LABEL: @scalable.store.nxv4f32(
+; DISABLED-NEXT:    tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
+; DISABLED-NEXT:    ret void
+;
+  tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %arg, ptr %p, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}


        


More information about the llvm-commits mailing list