[llvm] [LoopIdiom] Add minimal support for using llvm.experimental.memset.pattern (PR #118632)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 4 05:27:40 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Alex Bradbury (asb)

Changes:

Although there's potential for teaching LoopIdiomRecognize to produce the llvm.experimental.memset.pattern intrinsic in some cases where it currently bails out for memset_pattern16, this patch deliberately avoids that. Instead, the goal is to emit the intrinsic for exactly the cases where we currently emit the memset_pattern16 libcall, and to do so without excessively invasive refactoring. In the future, one could imagine removing support for emitting the libcall from LoopIdiomRecognize altogether and leaving that lowering to PreISelIntrinsicLowering, but that is left for later work.
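To make the intended transformation concrete, here is a condensed before/after sketch based on the `double_memset` case in the added test (the i64 value is the bit pattern of the stored double, and the third argument is the trip count rather than a byte count):

```llvm
; Before: a strided store of a loop-invariant pattern value.
for.body:
  %i = phi i64 [ %inc, %for.body ], [ 0, %entry ]
  %ptr = getelementptr inbounds double, ptr %p, i64 %i
  store double 3.14159e+00, ptr %ptr, align 1
  %inc = add nuw nsw i64 %i, 1
  %exitcond = icmp eq i64 %inc, 16
  br i1 %exitcond, label %for.cond.cleanup, label %for.body

; After (with the new flag enabled): a single intrinsic call in the preheader,
; with the store removed from the loop body.
entry:
  call void @llvm.experimental.memset.pattern.p0.i64.i64(ptr %p, i64 4614256650576692846, i64 16, i1 false)
  br label %for.body
```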

This support is gated behind the
`-loop-idiom-enable-memset-pattern-intrinsic` flag, so no functional change is intended unless that flag is specifically given.
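For reference, the new path can be exercised with an opt invocation like the RUN line in the added test:

```llvm
; RUN: opt -passes="loop-idiom" -loop-idiom-enable-memset-pattern-intrinsic < %s -S | FileCheck %s
```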

I've copied the memset-pattern-tbaa.ll test case, though I'd be happy to clean it up and diverge from that base if preferred. Here is the diff of the new memset.pattern intrinsic test against the libcall one: https://gist.github.com/asb/1cbd58ff04955a60d4fdffb1a2ca4ac1
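In short, the two test outputs differ mainly in how the pattern and length are expressed. A rough sketch of the contrast (the libcall line is my reconstruction of what memset-pattern-tbaa.ll checks for: the pattern lives in a 16-byte constant global and the length is 16 iterations * 8 bytes):

```llvm
; memset_pattern16 libcall: pattern in a 16-byte constant global, length in bytes.
call void @memset_pattern16(ptr %p, ptr @.memset_pattern, i64 128)

; memset.pattern intrinsic: pattern passed directly as an integer, length as a
; repetition count.
call void @llvm.experimental.memset.pattern.p0.i64.i64(ptr %p, i64 4614256650576692846, i64 16, i1 false)
```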

---
Full diff: https://github.com/llvm/llvm-project/pull/118632.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+90-34) 
- (added) llvm/test/Transforms/LoopIdiom/memset-pattern-intrinsic.ll (+112) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 05cf638d3f09df..9b2051af92e704 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -132,6 +132,11 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
              "with -Os/-Oz"),
     cl::init(true), cl::Hidden);
 
+static cl::opt<bool> EnableMemsetPatternIntrinsic(
+    "loop-idiom-enable-memset-pattern-intrinsic",
+    cl::desc("Enable use of the memset_pattern intrinsic."), cl::init(false),
+    cl::Hidden);
+
 namespace {
 
 class LoopIdiomRecognize {
@@ -306,7 +311,8 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
   HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
   HasMemcpy = TLI->has(LibFunc_memcpy);
 
-  if (HasMemset || HasMemsetPattern || HasMemcpy)
+  if (HasMemset || HasMemsetPattern || EnableMemsetPatternIntrinsic ||
+      HasMemcpy)
     if (SE->hasLoopInvariantBackedgeTakenCount(L))
       return runOnCountableLoop();
 
@@ -463,8 +469,10 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
     // It looks like we can use SplatValue.
     return LegalStoreKind::Memset;
   }
-  if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
-      // Don't create memset_pattern16s with address spaces.
+  if (!UnorderedAtomic && (HasMemsetPattern || EnableMemsetPatternIntrinsic) &&
+      !DisableLIRP::Memset &&
+      // Don't create memset_pattern16s / memset.pattern intrinsics with
+      // address spaces.
       StorePtr->getType()->getPointerAddressSpace() == 0 &&
       getMemSetPatternValue(StoredVal, DL)) {
     // It looks like we can use PatternValue!
@@ -1064,53 +1072,101 @@ bool LoopIdiomRecognize::processLoopStridedStore(
     return Changed;
 
   // Okay, everything looks good, insert the memset.
+  // MemsetArg is the number of bytes for the memset and memset_pattern16
+  // libcalls, and the number of pattern repetitions if the memset.pattern
+  // intrinsic is being used.
+  Value *MemsetArg;
+  std::optional<int64_t> BytesWritten = std::nullopt;
+
+  if (PatternValue && EnableMemsetPatternIntrinsic) {
+    const SCEV *TripCountS =
+        SE->getTripCountFromExitCount(BECount, IntIdxTy, CurLoop);
+    if (!Expander.isSafeToExpand(TripCountS))
+      return Changed;
+    const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+    if (!ConstStoreSize)
+      return Changed;
 
-  const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
-
-  // TODO: ideally we should still be able to generate memset if SCEV expander
-  // is taught to generate the dependencies at the latest point.
-  if (!Expander.isSafeToExpand(NumBytesS))
-    return Changed;
+    MemsetArg = Expander.expandCodeFor(TripCountS, IntIdxTy,
+                                       Preheader->getTerminator());
+    if (auto CI = dyn_cast<ConstantInt>(MemsetArg))
+      BytesWritten =
+          CI->getZExtValue() * ConstStoreSize->getValue()->getZExtValue();
+  } else {
+    const SCEV *NumBytesS =
+        getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
 
-  Value *NumBytes =
-      Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
+    // TODO: ideally we should still be able to generate memset if SCEV expander
+    // is taught to generate the dependencies at the latest point.
+    if (!Expander.isSafeToExpand(NumBytesS))
+      return Changed;
+    MemsetArg =
+        Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
+    if (auto CI = dyn_cast<ConstantInt>(MemsetArg))
+      BytesWritten = CI->getZExtValue();
+  }
+  assert(MemsetArg && "MemsetArg should have been set");
 
-  if (!SplatValue && !isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16))
+  if (!SplatValue && !(isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16) ||
+                       EnableMemsetPatternIntrinsic))
     return Changed;
 
   AAMDNodes AATags = TheStore->getAAMetadata();
   for (Instruction *Store : Stores)
     AATags = AATags.merge(Store->getAAMetadata());
-  if (auto CI = dyn_cast<ConstantInt>(NumBytes))
-    AATags = AATags.extendTo(CI->getZExtValue());
+  if (BytesWritten)
+    AATags = AATags.extendTo(BytesWritten.value());
   else
     AATags = AATags.extendTo(-1);
 
   CallInst *NewCall;
   if (SplatValue) {
     NewCall = Builder.CreateMemSet(
-        BasePtr, SplatValue, NumBytes, MaybeAlign(StoreAlignment),
+        BasePtr, SplatValue, MemsetArg, MaybeAlign(StoreAlignment),
         /*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias);
   } else {
-    assert (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16));
-    // Everything is emitted in default address space
-    Type *Int8PtrTy = DestInt8PtrTy;
-
-    StringRef FuncName = "memset_pattern16";
-    FunctionCallee MSP = getOrInsertLibFunc(M, *TLI, LibFunc_memset_pattern16,
-                            Builder.getVoidTy(), Int8PtrTy, Int8PtrTy, IntIdxTy);
-    inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI);
-
-    // Otherwise we should form a memset_pattern16.  PatternValue is known to be
-    // an constant array of 16-bytes.  Plop the value into a mergable global.
-    GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true,
-                                            GlobalValue::PrivateLinkage,
-                                            PatternValue, ".memset_pattern");
-    GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these.
-    GV->setAlignment(Align(16));
-    Value *PatternPtr = GV;
-    NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
+    assert(isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16) ||
+           EnableMemsetPatternIntrinsic);
+    if (EnableMemsetPatternIntrinsic) {
+      // Everything is emitted in default address space
+
+      assert(isa<SCEVConstant>(StoreSizeSCEV) &&
+             "Expected constant store size");
+      llvm::Type *IntType = Builder.getIntNTy(
+          cast<SCEVConstant>(StoreSizeSCEV)->getValue()->getZExtValue() * 8);
+
+      llvm::Value *BitcastedValue = Builder.CreateBitCast(StoredVal, IntType);
+
+      // (Optional) Use the bitcasted value for further operations
+
+      // Create the call to the intrinsic
+      NewCall =
+          Builder.CreateIntrinsic(Intrinsic::experimental_memset_pattern,
+                                  {DestInt8PtrTy, IntType, IntIdxTy},
+                                  {BasePtr, BitcastedValue, MemsetArg,
+                                   ConstantInt::getFalse(M->getContext())});
+    } else {
+      // Everything is emitted in default address space
+      Type *Int8PtrTy = DestInt8PtrTy;
+
+      StringRef FuncName = "memset_pattern16";
+      FunctionCallee MSP = getOrInsertLibFunc(M, *TLI, LibFunc_memset_pattern16,
+                                              Builder.getVoidTy(), Int8PtrTy,
+                                              Int8PtrTy, IntIdxTy);
+      inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI);
+
+      // Otherwise we should form a memset_pattern16.  PatternValue is known to
+      // be an constant array of 16-bytes.  Plop the value into a mergable
+      // global.
+      GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true,
+                                              GlobalValue::PrivateLinkage,
+                                              PatternValue, ".memset_pattern");
+      GV->setUnnamedAddr(
+          GlobalValue::UnnamedAddr::Global); // Ok to merge these.
+      GV->setAlignment(Align(16));
+      Value *PatternPtr = GV;
+      NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, MemsetArg});
+    }
 
     // Set the TBAA info if present.
     if (AATags.TBAA)
diff --git a/llvm/test/Transforms/LoopIdiom/memset-pattern-intrinsic.ll b/llvm/test/Transforms/LoopIdiom/memset-pattern-intrinsic.ll
new file mode 100644
index 00000000000000..104ee7b7dccd44
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-pattern-intrinsic.ll
@@ -0,0 +1,112 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes="loop-idiom" -loop-idiom-enable-memset-pattern-intrinsic < %s -S | FileCheck %s
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
+
+target triple = "x86_64-apple-darwin10.0.0"
+
+; Test case copied from memset-pattern-tbaa.ll.
+
+define dso_local void @double_memset(ptr nocapture %p) {
+; CHECK-LABEL: @double_memset(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @llvm.experimental.memset.pattern.p0.i64.i64(ptr [[P:%.*]], i64 4614256650576692846, i64 16, i1 false), !tbaa [[TBAA0:![0-9]+]]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[I_07:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr inbounds double, ptr [[P]], i64 [[I_07]]
+; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_07]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INC]], 16
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.07 = phi i64 [ %inc, %for.body ], [ 0, %entry ]
+  %ptr1 = getelementptr inbounds double, ptr %p, i64 %i.07
+  store double 3.14159e+00, ptr %ptr1, align 1, !tbaa !5
+  %inc = add nuw nsw i64 %i.07, 1
+  %exitcond.not = icmp eq i64 %inc, 16
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+
+define dso_local void @struct_memset(ptr nocapture %p) {
+; CHECK-LABEL: @struct_memset(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @llvm.experimental.memset.pattern.p0.i64.i64(ptr [[P:%.*]], i64 4614256650576692846, i64 16, i1 false), !tbaa [[TBAA4:![0-9]+]]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[I_07:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr inbounds double, ptr [[P]], i64 [[I_07]]
+; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_07]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INC]], 16
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.07 = phi i64 [ %inc, %for.body ], [ 0, %entry ]
+  %ptr1 = getelementptr inbounds double, ptr %p, i64 %i.07
+  store double 3.14159e+00, ptr %ptr1, align 1, !tbaa !10
+  %inc = add nuw nsw i64 %i.07, 1
+  %exitcond.not = icmp eq i64 %inc, 16
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define dso_local void @var_memset(ptr nocapture %p, i64 %len) {
+; CHECK-LABEL: @var_memset(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @llvm.experimental.memset.pattern.p0.i64.i64(ptr [[P:%.*]], i64 4614256650576692846, i64 [[LEN:%.*]], i1 false)
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[I_07:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr inbounds double, ptr [[P]], i64 [[I_07]]
+; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_07]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INC]], [[LEN]]
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.07 = phi i64 [ %inc, %for.body ], [ 0, %entry ]
+  %ptr1 = getelementptr inbounds double, ptr %p, i64 %i.07
+  store double 3.14159e+00, ptr %ptr1, align 1, !tbaa !10
+  %inc = add nuw nsw i64 %i.07, 1
+  %exitcond.not = icmp eq i64 %inc, %len
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+!5 = !{!6, !6, i64 0}
+!6 = !{!"double", !7, i64 0}
+!7 = !{!"omnipotent char", !8, i64 0}
+!8 = !{!"Simple C++ TBAA"}
+
+!15 = !{!8, i64 0, !"omnipotent char"}
+!17 = !{!15, i64 8, !"double"}
+!9 = !{!15, i64 32, !"_ZTS1A", !17, i64 0, i64 8, !17, i64 8, i64 8, !17, i64 16, i64 8, !17, i64 24, i64 8}
+!10 = !{!9, !17, i64 0, i64 1}
+
+!18 = !{!19, !20, i64 0}
+!19 = !{!"A", !20, i64 0, !22, i64 8}
+!20 = !{!"any pointer", !7, i64 0}
+!21 = !{!22, !20, i64 0}
+!22 = !{!"B", !20, i64 0}

``````````



https://github.com/llvm/llvm-project/pull/118632

