[llvm] Handle scalable store size in MemCpyOptimizer (PR #118957)

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 08:51:00 PST 2024


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/118957

From 7e2d60348850619fb7b0c8a88e92ab103f907d34 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 6 Dec 2024 11:08:21 +0000
Subject: [PATCH 1/2] Handle scalable store size in MemCpyOptimizer

MemCpyOptimizer crashes with an internal compiler error (ICE) when it tries
to create a `memset` with a scalable size.
---
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp |  3 +-
 .../CodeGen/AArch64/memset-scalable-size.ll   | 56 +++++++++++++++++++
 2 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/AArch64/memset-scalable-size.ll

diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 0cba5d077da62b..fc5f6ff2b7f377 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -800,8 +800,9 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
     // in subsequent passes.
     auto *T = V->getType();
     if (T->isAggregateType()) {
-      uint64_t Size = DL.getTypeStoreSize(T);
       IRBuilder<> Builder(SI);
+      Value *Size =
+          Builder.CreateTypeSize(Builder.getInt64Ty(), DL.getTypeStoreSize(T));
       auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size,
                                      SI->getAlign());
       M->copyMetadata(*SI, LLVMContext::MD_DIAssignID);
diff --git a/llvm/test/CodeGen/AArch64/memset-scalable-size.ll b/llvm/test/CodeGen/AArch64/memset-scalable-size.ll
new file mode 100644
index 00000000000000..8ea6330f235a69
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/memset-scalable-size.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S --passes=memcpyopt < %s | FileCheck %s
+target triple = "aarch64-unknown-linux"
+
+define void @f0() {
+; CHECK-LABEL: define void @f0() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[P:%.*]] = alloca { <vscale x 16 x i1>, <vscale x 16 x i1> }, align 2
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 2 [[P]], i8 0, i64 [[TMP1]], i1 false)
+; CHECK-NEXT:    call void @g(ptr [[P]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p = alloca { <vscale x 16 x i1>, <vscale x 16 x i1>}, align 2
+  store { <vscale x 16 x i1>, <vscale x 16 x i1> } zeroinitializer, ptr %p, align 2
+  call void @g(ptr %p)
+  ret void
+}
+
+define void @f1() {
+; CHECK-LABEL: define void @f1() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[P:%.*]] = alloca { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> }, align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 48
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 16 [[P]], i8 0, i64 [[TMP1]], i1 false)
+; CHECK-NEXT:    call void @g(ptr [[P]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p = alloca {<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> }, align 16
+  store {<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } zeroinitializer, ptr %p, align 16
+  call void @g(ptr %p)
+  ret void
+}
+
+define void @f2() {
+; CHECK-LABEL: define void @f2() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[P:%.*]] = alloca { <vscale x 8 x double>, <vscale x 8 x double>, <vscale x 8 x double> }, align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 192
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 16 [[P]], i8 0, i64 [[TMP1]], i1 false)
+; CHECK-NEXT:    call void @g(ptr [[P]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p = alloca {<vscale x 8 x double>, <vscale x 8 x double>, <vscale x 8 x double> }, align 16
+  store {<vscale x 8 x double>, <vscale x 8 x double>, <vscale x 8 x double> } zeroinitializer, ptr %p, align 16
+  call void @g(ptr %p)
+  ret void
+}
+
+declare void @g(ptr)

From eca59d2a9d990d4dafca1dd9714f9bfeb851ea4c Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 6 Dec 2024 16:28:38 +0000
Subject: [PATCH 2/2] [fixup] Don't promote scalable aggregate stores to memset
 and move test to vscale-memset.ll

---
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp | 67 ++++++++++---------
 .../CodeGen/AArch64/memset-scalable-size.ll   | 56 ----------------
 .../Transforms/MemCpyOpt/vscale-memset.ll     | 14 +++-
 3 files changed, 48 insertions(+), 89 deletions(-)
 delete mode 100644 llvm/test/CodeGen/AArch64/memset-scalable-size.ll

diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index fc5f6ff2b7f377..bb98b3d1c07259 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -787,44 +787,47 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
   // Ensure that the value being stored is something that can be memset'able a
   // byte at a time like "0" or "-1" or any width, as well as things like
   // 0xA0A0A0A0 and 0.0.
-  auto *V = SI->getOperand(0);
-  if (Value *ByteVal = isBytewiseValue(V, DL)) {
-    if (Instruction *I =
-            tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) {
-      BBI = I->getIterator(); // Don't invalidate iterator.
-      return true;
-    }
+  Value *V = SI->getOperand(0);
+  Value *ByteVal = isBytewiseValue(V, DL);
+  if (!ByteVal)
+    return false;
 
-    // If we have an aggregate, we try to promote it to memset regardless
-    // of opportunity for merging as it can expose optimization opportunities
-    // in subsequent passes.
-    auto *T = V->getType();
-    if (T->isAggregateType()) {
-      IRBuilder<> Builder(SI);
-      Value *Size =
-          Builder.CreateTypeSize(Builder.getInt64Ty(), DL.getTypeStoreSize(T));
-      auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size,
-                                     SI->getAlign());
-      M->copyMetadata(*SI, LLVMContext::MD_DIAssignID);
+  if (Instruction *I =
+          tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) {
+    BBI = I->getIterator(); // Don't invalidate iterator.
+    return true;
+  }
+
+  // If we have an aggregate, we try to promote it to memset regardless
+  // of opportunity for merging as it can expose optimization opportunities
+  // in subsequent passes.
+  auto *T = V->getType();
+  if (!T->isAggregateType())
+    return false;
 
-      LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n");
+  TypeSize Size = DL.getTypeStoreSize(T);
+  if (Size.isScalable())
+    return false;
 
-      // The newly inserted memset is immediately overwritten by the original
-      // store, so we do not need to rename uses.
-      auto *StoreDef = cast<MemoryDef>(MSSA->getMemoryAccess(SI));
-      auto *NewAccess = MSSAU->createMemoryAccessBefore(M, nullptr, StoreDef);
-      MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/false);
+  IRBuilder<> Builder(SI);
+  auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size,
+                                 SI->getAlign());
+  M->copyMetadata(*SI, LLVMContext::MD_DIAssignID);
 
-      eraseInstruction(SI);
-      NumMemSetInfer++;
+  LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n");
 
-      // Make sure we do not invalidate the iterator.
-      BBI = M->getIterator();
-      return true;
-    }
-  }
+  // The newly inserted memset is immediately overwritten by the original
+  // store, so we do not need to rename uses.
+  auto *StoreDef = cast<MemoryDef>(MSSA->getMemoryAccess(SI));
+  auto *NewAccess = MSSAU->createMemoryAccessBefore(M, nullptr, StoreDef);
+  MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/false);
 
-  return false;
+  eraseInstruction(SI);
+  NumMemSetInfer++;
+
+  // Make sure we do not invalidate the iterator.
+  BBI = M->getIterator();
+  return true;
 }
 
 bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
diff --git a/llvm/test/CodeGen/AArch64/memset-scalable-size.ll b/llvm/test/CodeGen/AArch64/memset-scalable-size.ll
deleted file mode 100644
index 8ea6330f235a69..00000000000000
--- a/llvm/test/CodeGen/AArch64/memset-scalable-size.ll
+++ /dev/null
@@ -1,56 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt -S --passes=memcpyopt < %s | FileCheck %s
-target triple = "aarch64-unknown-linux"
-
-define void @f0() {
-; CHECK-LABEL: define void @f0() {
-; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[P:%.*]] = alloca { <vscale x 16 x i1>, <vscale x 16 x i1> }, align 2
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 2 [[P]], i8 0, i64 [[TMP1]], i1 false)
-; CHECK-NEXT:    call void @g(ptr [[P]])
-; CHECK-NEXT:    ret void
-;
-entry:
-  %p = alloca { <vscale x 16 x i1>, <vscale x 16 x i1>}, align 2
-  store { <vscale x 16 x i1>, <vscale x 16 x i1> } zeroinitializer, ptr %p, align 2
-  call void @g(ptr %p)
-  ret void
-}
-
-define void @f1() {
-; CHECK-LABEL: define void @f1() {
-; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[P:%.*]] = alloca { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> }, align 16
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 48
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 16 [[P]], i8 0, i64 [[TMP1]], i1 false)
-; CHECK-NEXT:    call void @g(ptr [[P]])
-; CHECK-NEXT:    ret void
-;
-entry:
-  %p = alloca {<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> }, align 16
-  store {<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } zeroinitializer, ptr %p, align 16
-  call void @g(ptr %p)
-  ret void
-}
-
-define void @f2() {
-; CHECK-LABEL: define void @f2() {
-; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[P:%.*]] = alloca { <vscale x 8 x double>, <vscale x 8 x double>, <vscale x 8 x double> }, align 16
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 192
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 16 [[P]], i8 0, i64 [[TMP1]], i1 false)
-; CHECK-NEXT:    call void @g(ptr [[P]])
-; CHECK-NEXT:    ret void
-;
-entry:
-  %p = alloca {<vscale x 8 x double>, <vscale x 8 x double>, <vscale x 8 x double> }, align 16
-  store {<vscale x 8 x double>, <vscale x 8 x double>, <vscale x 8 x double> } zeroinitializer, ptr %p, align 16
-  call void @g(ptr %p)
-  ret void
-}
-
-declare void @g(ptr)
diff --git a/llvm/test/Transforms/MemCpyOpt/vscale-memset.ll b/llvm/test/Transforms/MemCpyOpt/vscale-memset.ll
index b4ab443fdfb68c..45de52065cd5c1 100644
--- a/llvm/test/Transforms/MemCpyOpt/vscale-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/vscale-memset.ll
@@ -8,7 +8,7 @@
 define void @foo(ptr %p) {
 ; CHECK-LABEL: @foo(
 ; CHECK-NEXT:    store <vscale x 16 x i8> zeroinitializer, ptr [[P:%.*]], align 16
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr <vscale x 16 x i8>, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr <vscale x 16 x i8>, ptr [[P]], i64 1
 ; CHECK-NEXT:    store <vscale x 16 x i8> zeroinitializer, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret void
 ;
@@ -18,6 +18,18 @@ define void @foo(ptr %p) {
   ret void
 }
 
+; Test that the compiler does not crash on a store of a scalable aggregate type.
+define void @test_no_crash_scalable_agg(ptr %p) {
+; CHECK-LABEL: @test_no_crash_scalable_agg(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store { <vscale x 16 x i1>, <vscale x 16 x i1> } zeroinitializer, ptr [[P:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  store { <vscale x 16 x i1>, <vscale x 16 x i1> } zeroinitializer, ptr %p, align 2
+  ret void
+}
+
 ; Positive test
 
 define void @memset_vscale_index_zero(ptr %p, i8 %z) {



More information about the llvm-commits mailing list