[llvm] [SROA] Allow rewriting memcpy depending on tbaa.struct (PR #77597)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 04:54:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Björn Pettersson (bjope)

<details>
<summary>Changes</summary>

The bugfix in commit 54067c5fbe, related to
  https://github.com/llvm/llvm-project/issues/64081
limited the ability of SROA to handle non byte-sized types
when used in aggregates that are memcpy'd.

The main problem was that the LLVM type used in an alloca doesn't
always reflect whether the stack slot can be used for multiple
types (typically unions). So even if we, for example, have
  %p = alloca i6
that stack slot may be used for types other than i6. And it
would be legal to, for example, store an i8 to that stack slot.

Thus, if %p was dereferenced in a memcpy, we needed to consider
that padding bits (seen from the i6 perspective) could also be
of importance.

The limitation added to SROA in commit https://github.com/llvm/llvm-project/commit/54067c5fbe9fc13ab195cdddb8f17e18d72b5fe4 resulted
in huge regressions for a downstream target. Since the frontend
typically emits memcpy for aggregate copies, it seems quite normal
that one ends up with a memcpy that is copying padding bits even
when there are no unions or type punning. So that limitation
seems unfortunate in general.

In this patch we try to lift the restrictions a bit. If the
memcpy is decorated with tbaa.struct metadata we look at that
metadata. If we find that the slice used by our new alloca is
touching memory described by a single scalar type according to
the tbaa.struct, then we assume that the type derived for the
new alloca is correct for all accesses made to the stack slot.
And then we can allow replacing the memcpy by regular load/store
instructions operating on that type (disregarding any padding
bits).

---
Full diff: https://github.com/llvm/llvm-project/pull/77597.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/SROA.cpp (+36-2) 
- (modified) llvm/test/Transforms/SROA/pr64081.ll (+74) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 75cddfa16d6db5..4439765f68db0b 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -3276,6 +3276,39 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
     // memmove with memcpy, and we don't need to worry about all manner of
     // downsides to splitting and transforming the operations.
 
+    // The tbaa.struct is only being explicit about byte padding. Here we assume
+    // that if the derived type used for the NewAI maps to a single scalar type,
+    // as given by the tbaa.struct, then it is safe to assume that we can use
+    // that type when doing the copying even if it include bit padding. If there
+    // for example would be a union of "_BitInt(3)" and "char" types the
+    // tbaa.struct would have multiple entries indicating the different types
+    // (or there wouldn't be any tbaa.struct)..
+    auto IsSingleTypeAccordingToTBAA = [&]() -> bool {
+      // Only consider the case when we have a tbaa.struct.
+      if (!(AATags && AATags.TBAAStruct))
+        return false;
+      MDNode *MD = AATags.TBAAStruct;
+      uint64_t Offset = NewBeginOffset - BeginOffset;
+      unsigned Count = 0;
+      for (size_t i = 0, size = MD->getNumOperands(); i < size; i += 3) {
+        uint64_t InnerOffset =
+          mdconst::extract<ConstantInt>(MD->getOperand(i))->getZExtValue();
+        uint64_t InnerSize =
+          mdconst::extract<ConstantInt>(MD->getOperand(i + 1))->getZExtValue();
+        // Ignore entries that aren't overlapping with our slice.
+        if (InnerOffset + InnerSize <= Offset ||
+            InnerOffset >= Offset + SliceSize)
+          continue;
+        // Only allow a single match (no unions).
+        if (++Count > 1)
+          return false;
+        // Size/offset must match up.
+        if (InnerSize != SliceSize || Offset != InnerOffset)
+          return false;
+      }
+      return Count == 1;
+    };
+
     // If this doesn't map cleanly onto the alloca type, and that type isn't
     // a single value type, just emit a memcpy.
     bool EmitMemCpy =
@@ -3283,8 +3316,9 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
         (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset ||
          SliceSize !=
              DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedValue() ||
-         !DL.typeSizeEqualsStoreSize(NewAI.getAllocatedType()) ||
-         !NewAI.getAllocatedType()->isSingleValueType());
+         !NewAI.getAllocatedType()->isSingleValueType() ||
+         (!DL.typeSizeEqualsStoreSize(NewAI.getAllocatedType()) &&
+          !IsSingleTypeAccordingToTBAA()));
 
     // If we're just going to emit a memcpy, the alloca hasn't changed, and the
     // size hasn't been shrunk based on analysis of the viable range, this is
diff --git a/llvm/test/Transforms/SROA/pr64081.ll b/llvm/test/Transforms/SROA/pr64081.ll
index 4b893842138263..ba83e495f56c27 100644
--- a/llvm/test/Transforms/SROA/pr64081.ll
+++ b/llvm/test/Transforms/SROA/pr64081.ll
@@ -30,3 +30,77 @@ bb:
 declare void @use(ptr)
 
 declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
+
+
+; No unions or overlaps in the tbaa.struct. So we can rely on the types
+define void @test2(i3 %x) {
+; CHECK-LABEL: define void @test2(
+; CHECK-SAME: i3 [[X:%.*]]) {
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[RES:%.*]] = alloca [[B:%.*]], align 8
+; CHECK-NEXT:    store i1 true, ptr [[RES]], align 1, !tbaa.struct [[TBAA_STRUCT0:![0-9]+]]
+; CHECK-NEXT:    [[TMP_SROA_2_0_RES_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[RES]], i64 1
+; CHECK-NEXT:    store i3 [[X]], ptr [[TMP_SROA_2_0_RES_SROA_IDX]], align 1, !tbaa.struct [[TBAA_STRUCT7:![0-9]+]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i8 @use(ptr [[RES]])
+; CHECK-NEXT:    ret void
+;
+bb:
+  %res = alloca %B
+  %tmp = alloca %B
+  %tmp.1 = getelementptr i8, ptr %tmp, i64 1
+  store i1 1, ptr %tmp
+  store i3 %x, ptr %tmp.1
+  call void @llvm.memcpy.p0.p0.i64(ptr %res, ptr %tmp, i64 2, i1 false), !tbaa.struct !6
+  call i8 @use(ptr %res)
+  ret void
+}
+
+; Union preventing SROA from removing the memcpy for the first byte.
+define void @test3(i3 %x) {
+; CHECK-LABEL: define void @test3(
+; CHECK-SAME: i3 [[X:%.*]]) {
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[RES:%.*]] = alloca [[B:%.*]], align 8
+; CHECK-NEXT:    [[TMP_SROA_0:%.*]] = alloca i1, align 8
+; CHECK-NEXT:    store i1 true, ptr [[TMP_SROA_0]], align 8
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 1 [[RES]], ptr align 8 [[TMP_SROA_0]], i64 1, i1 false), !tbaa.struct [[TBAA_STRUCT8:![0-9]+]]
+; CHECK-NEXT:    [[TMP_SROA_2_0_RES_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[RES]], i64 1
+; CHECK-NEXT:    store i3 [[X]], ptr [[TMP_SROA_2_0_RES_SROA_IDX]], align 1, !tbaa.struct [[TBAA_STRUCT7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i8 @use(ptr [[RES]])
+; CHECK-NEXT:    ret void
+;
+bb:
+  %res = alloca %B
+  %tmp = alloca %B
+  %tmp.1 = getelementptr i8, ptr %tmp, i64 1
+  store i1 1, ptr %tmp
+  store i3 %x, ptr %tmp.1
+  call void @llvm.memcpy.p0.p0.i64(ptr %res, ptr %tmp, i64 2, i1 false), !tbaa.struct !9
+  call i8 @use(ptr %res)
+  ret void
+}
+
+!1 = !{!"_BitInt(7)", !4, i64 0}
+!2 = !{!"_BitInt(1)", !4, i64 0}
+!3 = !{!"_BitInt(3)", !4, i64 0}
+!4 = !{!"omnipotent char", !5, i64 0}
+!5 = !{!"Simple C++ TBAA"}
+!6 = !{i64 0, i64 1, !7, i64 1, i64 1, !8}
+!7 = !{!2, !2, i64 0}
+!8 = !{!3, !3, i64 0}
+!9 = !{i64 0, i64 1, !10, i64 0, i64 1, !7, i64 1, i64 1, !8}
+!10 = !{!1, !1, i64 0}
+
+;.
+; CHECK: [[TBAA_STRUCT0]] = !{i64 0, i64 1, [[META1:![0-9]+]], i64 1, i64 1, [[META5:![0-9]+]]}
+; CHECK: [[META1]] = !{[[META2:![0-9]+]], [[META2]], i64 0}
+; CHECK: [[META2]] = !{!"_BitInt(1)", [[META3:![0-9]+]], i64 0}
+; CHECK: [[META3]] = !{!"omnipotent char", [[META4:![0-9]+]], i64 0}
+; CHECK: [[META4]] = !{!"Simple C++ TBAA"}
+; CHECK: [[META5]] = !{[[META6:![0-9]+]], [[META6]], i64 0}
+; CHECK: [[META6]] = !{!"_BitInt(3)", [[META3]], i64 0}
+; CHECK: [[TBAA_STRUCT7]] = !{i64 0, i64 1, [[META5]]}
+; CHECK: [[TBAA_STRUCT8]] = !{i64 0, i64 1, [[META9:![0-9]+]], i64 0, i64 1, [[META1]], i64 1, i64 1, [[META5]]}
+; CHECK: [[META9]] = !{[[META10:![0-9]+]], [[META10]], i64 0}
+; CHECK: [[META10]] = !{!"_BitInt(7)", [[META3]], i64 0}
+;.

``````````

</details>


https://github.com/llvm/llvm-project/pull/77597


More information about the llvm-commits mailing list