[llvm] [IR][TBAA] Allow multiple fields with same offset in TBAA struct-path (PR #76356)

Bushev Dmitry via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 31 09:31:00 PST 2024


https://github.com/dybv-sc updated https://github.com/llvm/llvm-project/pull/76356

>From ae3c9059df82b1845a91ed06319202d903d1a242 Mon Sep 17 00:00:00 2001
From: Dmitry Bushev <dmitry.bushev at syntacore.com>
Date: Mon, 25 Dec 2023 13:16:43 +0300
Subject: [PATCH] [IR][TBAA] Allow multiple fields with same offset in TBAA
 struct-path

Add support for multiple fields having the same offset in TBAA struct-path
metadata nodes. The primary goal is to allow union-like structures
to participate in TBAA struct-path resolution.
---
 llvm/docs/LangRef.rst                         |  17 ++-
 llvm/include/llvm/IR/Verifier.h               |  11 +-
 llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp  |  57 ++++---
 llvm/lib/IR/Verifier.cpp                      | 143 ++++++++++++------
 .../TypeBasedAliasAnalysis/aggregates.ll      |  20 +++
 llvm/test/Verifier/tbaa.ll                    |  10 +-
 6 files changed, 177 insertions(+), 81 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 7a7ddc59ba985..4882256d30d38 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -6424,9 +6424,10 @@ tuples this way:
    undefined if ``Offset`` is non-zero.
 
  * If ``BaseTy`` is a struct type then ``ImmediateParent(BaseTy, Offset)``
-   is ``(NewTy, NewOffset)`` where ``NewTy`` is the type contained in
-   ``BaseTy`` at offset ``Offset`` and ``NewOffset`` is ``Offset`` adjusted
-   to be relative within that inner type.
+   is an array of ``(NewTy[N], NewOffset)`` where ``NewTy[N]`` is the Nth type
+   contained in ``BaseTy`` at offset ``Offset`` and ``NewOffset`` is
+   ``Offset`` adjusted to be relative within that inner type. Multiple types
+   occupying the same offset make it possible to describe union-like structures.
 
 A memory access with an access tag ``(BaseTy1, AccessTy1, Offset1)``
 aliases a memory access with an access tag ``(BaseTy2, AccessTy2,
@@ -6437,9 +6438,9 @@ As a concrete example, the type descriptor graph for the following program
 
 .. code-block:: c
 
-    struct Inner {
+    union Inner {
       int i;    // offset 0
-      float f;  // offset 4
+      float f;  // offset 0
     };
 
     struct Outer {
@@ -6451,7 +6452,7 @@ As a concrete example, the type descriptor graph for the following program
     void f(struct Outer* outer, struct Inner* inner, float* f, int* i, char* c) {
       outer->f = 0;            // tag0: (OuterStructTy, FloatScalarTy, 0)
       outer->inner_a.i = 0;    // tag1: (OuterStructTy, IntScalarTy, 12)
-      outer->inner_a.f = 0.0;  // tag2: (OuterStructTy, FloatScalarTy, 16)
+      outer->inner_a.f = 0.0;  // tag2: (OuterStructTy, FloatScalarTy, 12)
       *f = 0.0;                // tag3: (FloatScalarTy, FloatScalarTy, 0)
     }
 
@@ -6465,13 +6466,13 @@ type):
     FloatScalarTy = ("float", CharScalarTy, 0)
     DoubleScalarTy = ("double", CharScalarTy, 0)
     IntScalarTy = ("int", CharScalarTy, 0)
-    InnerStructTy = {"Inner" (IntScalarTy, 0), (FloatScalarTy, 4)}
+    InnerStructTy = {"Inner" (IntScalarTy, 0), (FloatScalarTy, 0)}
     OuterStructTy = {"Outer", (FloatScalarTy, 0), (DoubleScalarTy, 4),
                      (InnerStructTy, 12)}
 
 
 with (e.g.) ``ImmediateParent(OuterStructTy, 12)`` = ``(InnerStructTy,
-0)``, ``ImmediateParent(InnerStructTy, 0)`` = ``(IntScalarTy, 0)``, and
+0)``, ``ImmediateParent(InnerStructTy, 0)`` = ``(IntScalarTy, 0), (FloatScalarTy, 0)``, and
 ``ImmediateParent(IntScalarTy, 0)`` = ``(CharScalarTy, 0)``.
 
 .. _tbaa_node_representation:
diff --git a/llvm/include/llvm/IR/Verifier.h b/llvm/include/llvm/IR/Verifier.h
index b25f8eb77ee38..95db2c4b16eca 100644
--- a/llvm/include/llvm/IR/Verifier.h
+++ b/llvm/include/llvm/IR/Verifier.h
@@ -59,8 +59,15 @@ class TBAAVerifier {
 
   /// \name Helper functions used by \c visitTBAAMetadata.
   /// @{
-  MDNode *getFieldNodeFromTBAABaseNode(Instruction &I, const MDNode *BaseNode,
-                                       APInt &Offset, bool IsNewFormat);
+  std::vector<MDNode *> getFieldNodeFromTBAABaseNode(Instruction &I,
+                                                     const MDNode *BaseNode,
+                                                     APInt &Offset,
+                                                     bool IsNewFormat);
+  bool findAccessTypeNode(Instruction &I,
+                          SmallPtrSetImpl<const MDNode *> &StructPath,
+                          APInt Offset, bool IsNewFormat,
+                          const MDNode *AccessType, const MDNode *BaseNode,
+                          const MDNode *MD);
   TBAAVerifier::TBAABaseNodeSummary verifyTBAABaseNode(Instruction &I,
                                                        const MDNode *BaseNode,
                                                        bool IsNewFormat);
diff --git a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
index e4dc1a867f6f0..c0ab88d85a85b 100644
--- a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
@@ -121,6 +121,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include <cassert>
 #include <cstdint>
+#include <stack>
 
 using namespace llvm;
 
@@ -299,9 +300,10 @@ class TBAAStructTypeNode {
     return TBAAStructTypeNode(TypeNode);
   }
 
-  /// Get this TBAAStructTypeNode's field in the type DAG with
+  /// Get this TBAAStructTypeNode's fields in the type DAG with
   /// given offset. Update the offset to be relative to the field type.
-  TBAAStructTypeNode getField(uint64_t &Offset) const {
+  /// There could be multiple fields with same offset.
+  std::vector<TBAAStructTypeNode> getField(uint64_t &Offset) const {
     bool NewFormat = isNewFormat();
     const ArrayRef<MDOperand> Operands = Node->operands();
     const unsigned NumOperands = Operands.size();
@@ -309,11 +311,11 @@ class TBAAStructTypeNode {
     if (NewFormat) {
       // New-format root and scalar type nodes have no fields.
       if (NumOperands < 6)
-        return TBAAStructTypeNode();
+        return {TBAAStructTypeNode()};
     } else {
       // Parent can be omitted for the root node.
       if (NumOperands < 2)
-        return TBAAStructTypeNode();
+        return {TBAAStructTypeNode()};
 
       // Fast path for a scalar type node and a struct type node with a single
       // field.
@@ -325,8 +327,8 @@ class TBAAStructTypeNode {
         Offset -= Cur;
         MDNode *P = dyn_cast_or_null<MDNode>(Operands[1]);
         if (!P)
-          return TBAAStructTypeNode();
-        return TBAAStructTypeNode(P);
+          return {TBAAStructTypeNode()};
+        return {TBAAStructTypeNode(P)};
       }
     }
 
@@ -336,6 +338,8 @@ class TBAAStructTypeNode {
     unsigned NumOpsPerField = NewFormat ? 3 : 2;
     unsigned TheIdx = 0;
 
+    std::vector<TBAAStructTypeNode> Ret;
+
     for (unsigned Idx = FirstFieldOpNo; Idx < NumOperands;
          Idx += NumOpsPerField) {
       uint64_t Cur =
@@ -353,10 +357,20 @@ class TBAAStructTypeNode {
     uint64_t Cur =
         mdconst::extract<ConstantInt>(Operands[TheIdx + 1])->getZExtValue();
     Offset -= Cur;
+
+    // Collect all fields that have right offset.
     MDNode *P = dyn_cast_or_null<MDNode>(Operands[TheIdx]);
-    if (!P)
-      return TBAAStructTypeNode();
-    return TBAAStructTypeNode(P);
+    Ret.emplace_back(P ? TBAAStructTypeNode(P) : TBAAStructTypeNode());
+
+    while (TheIdx > FirstFieldOpNo) {
+      TheIdx -= NumOpsPerField;
+      auto Val = mdconst::extract<ConstantInt>(Operands[TheIdx + 1]);
+      if (Cur != Val->getZExtValue())
+        break;
+      MDNode *P = dyn_cast_or_null<MDNode>(Operands[TheIdx]);
+      P ? Ret.emplace_back(P) : Ret.emplace_back();
+    }
+    return Ret;
   }
 };
 
@@ -599,17 +613,24 @@ static bool mayBeAccessToSubobjectOf(TBAAStructTagNode BaseTag,
   // from the base type, follow the edge with the correct offset in the type DAG
   // and adjust the offset until we reach the field type or until we reach the
   // access type.
+  // If multiple fields have same offset in some base type, then scan each such
+  // field.
   bool NewFormat = BaseTag.isNewFormat();
   TBAAStructTypeNode BaseType(BaseTag.getBaseType());
   uint64_t OffsetInBase = BaseTag.getOffset();
 
-  for (;;) {
-    // In the old format there is no distinction between fields and parent
-    // types, so in this case we consider all nodes up to the root.
-    if (!BaseType.getNode()) {
-      assert(!NewFormat && "Did not see access type in access path!");
-      break;
-    }
+  SmallVector<std::pair<TBAAStructTypeNode, uint64_t>, 4> ToCheck;
+  ToCheck.emplace_back(BaseType, OffsetInBase);
+  while (!ToCheck.empty()) {
+    std::tie(BaseType, OffsetInBase) = ToCheck.back();
+    ToCheck.pop_back();
+
+    // In case if root is reached, still check the remaining candidates.
+    // For new format it is always expected for access type to be found.
+    // For old format all nodes up to the root are considered from all
+    // candidates.
+    if (!BaseType.getNode())
+      continue;
 
     if (BaseType.getNode() == SubobjectTag.getBaseType()) {
       bool SameMemberAccess = OffsetInBase == SubobjectTag.getOffset();
@@ -627,13 +648,15 @@ static bool mayBeAccessToSubobjectOf(TBAAStructTagNode BaseTag,
 
     // Follow the edge with the correct offset. Offset will be adjusted to
     // be relative to the field type.
-    BaseType = BaseType.getField(OffsetInBase);
+    for (auto &&F : BaseType.getField(OffsetInBase))
+      ToCheck.emplace_back(F, OffsetInBase);
   }
 
   // If the base object has a direct or indirect field of the subobject's type,
   // then this may be an access to that field. We need this to check now that
   // we support aggregates as access types.
   if (NewFormat) {
+    assert(!NewFormat && "Did not see access type in access path!");
     // TBAAStructTypeNode BaseAccessType(BaseTag.getAccessType());
     TBAAStructTypeNode FieldType(SubobjectTag.getBaseType());
     if (hasField(BaseType, FieldType)) {
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 91cf91fbc788b..ae3a88761d8c7 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6973,48 +6973,56 @@ bool TBAAVerifier::isValidScalarTBAANode(const MDNode *MD) {
   return Result;
 }
 
-/// Returns the field node at the offset \p Offset in \p BaseNode.  Update \p
-/// Offset in place to be the offset within the field node returned.
+/// Returns one or several field nodes at the offset \p Offset in \p BaseNode.
+/// Returns empty vector if \p BaseNode has no fields with specified offset.
+/// Update \p Offset in place to be the offset within the field node returned.
 ///
 /// We assume we've okayed \p BaseNode via \c verifyTBAABaseNode.
-MDNode *TBAAVerifier::getFieldNodeFromTBAABaseNode(Instruction &I,
-                                                   const MDNode *BaseNode,
-                                                   APInt &Offset,
-                                                   bool IsNewFormat) {
+std::vector<MDNode *> TBAAVerifier::getFieldNodeFromTBAABaseNode(
+    Instruction &I, const MDNode *BaseNode, APInt &Offset, bool IsNewFormat) {
   assert(BaseNode->getNumOperands() >= 2 && "Invalid base node!");
 
   // Scalar nodes have only one possible "field" -- their parent in the access
   // hierarchy.  Offset must be zero at this point, but our caller is supposed
   // to check that.
   if (BaseNode->getNumOperands() == 2)
-    return cast<MDNode>(BaseNode->getOperand(1));
+    return {cast<MDNode>(BaseNode->getOperand(1))};
 
   unsigned FirstFieldOpNo = IsNewFormat ? 3 : 1;
   unsigned NumOpsPerField = IsNewFormat ? 3 : 2;
+
+  unsigned LastIdx = BaseNode->getNumOperands() - NumOpsPerField;
   for (unsigned Idx = FirstFieldOpNo; Idx < BaseNode->getNumOperands();
            Idx += NumOpsPerField) {
     auto *OffsetEntryCI =
         mdconst::extract<ConstantInt>(BaseNode->getOperand(Idx + 1));
     if (OffsetEntryCI->getValue().ugt(Offset)) {
       if (Idx == FirstFieldOpNo) {
-        CheckFailed("Could not find TBAA parent in struct type node", &I,
-                    BaseNode, &Offset);
-        return nullptr;
+        return {};
       }
 
-      unsigned PrevIdx = Idx - NumOpsPerField;
-      auto *PrevOffsetEntryCI =
-          mdconst::extract<ConstantInt>(BaseNode->getOperand(PrevIdx + 1));
-      Offset -= PrevOffsetEntryCI->getValue();
-      return cast<MDNode>(BaseNode->getOperand(PrevIdx));
+      LastIdx = Idx - NumOpsPerField;
+      break;
     }
   }
 
-  unsigned LastIdx = BaseNode->getNumOperands() - NumOpsPerField;
   auto *LastOffsetEntryCI = mdconst::extract<ConstantInt>(
       BaseNode->getOperand(LastIdx + 1));
-  Offset -= LastOffsetEntryCI->getValue();
-  return cast<MDNode>(BaseNode->getOperand(LastIdx));
+  auto LastOffsetVal = LastOffsetEntryCI->getValue();
+  Offset -= LastOffsetVal;
+
+  std::vector<MDNode *> Ret;
+  Ret.emplace_back(cast<MDNode>(BaseNode->getOperand(LastIdx)));
+  while (LastIdx > FirstFieldOpNo) {
+    LastIdx -= NumOpsPerField;
+    LastOffsetEntryCI =
+        mdconst::extract<ConstantInt>(BaseNode->getOperand(LastIdx + 1));
+    if (LastOffsetEntryCI->getValue() != LastOffsetVal)
+      break;
+    Ret.emplace_back(cast<MDNode>(BaseNode->getOperand(LastIdx)));
+  }
+
+  return Ret;
 }
 
 static bool isNewFormatTBAATypeNode(llvm::MDNode *Type) {
@@ -7091,47 +7099,84 @@ bool TBAAVerifier::visitTBAAMetadata(Instruction &I, const MDNode *MD) {
   CheckTBAA(OffsetCI, "Offset must be constant integer", &I, MD);
 
   APInt Offset = OffsetCI->getValue();
-  bool SeenAccessTypeInPath = false;
 
-  SmallPtrSet<MDNode *, 4> StructPath;
+  SmallPtrSet<const MDNode *, 4> StructPath;
 
-  for (/* empty */; BaseNode && !IsRootTBAANode(BaseNode);
-       BaseNode = getFieldNodeFromTBAABaseNode(I, BaseNode, Offset,
-                                               IsNewFormat)) {
-    if (!StructPath.insert(BaseNode).second) {
-      CheckFailed("Cycle detected in struct path", &I, MD);
-      return false;
-    }
+  auto &&[Invalid, BaseNodeBitWidth] =
+      verifyTBAABaseNode(I, BaseNode, IsNewFormat);
 
-    bool Invalid;
-    unsigned BaseNodeBitWidth;
-    std::tie(Invalid, BaseNodeBitWidth) = verifyTBAABaseNode(I, BaseNode,
-                                                             IsNewFormat);
+  // If the base node is invalid in itself, then we've already printed all the
+  // errors we wanted to print.
+  if (Invalid)
+    return false;
 
-    // If the base node is invalid in itself, then we've already printed all the
-    // errors we wanted to print.
-    if (Invalid)
-      return false;
+  bool SeenAccessTypeInPath = BaseNode == AccessType;
+  if (SeenAccessTypeInPath) {
+    CheckTBAA(Offset == 0, "Offset not zero at the point of scalar access", &I,
+              MD, &Offset);
+    if (IsNewFormat)
+      return true;
+  }
 
-    SeenAccessTypeInPath |= BaseNode == AccessType;
+  CheckTBAA(findAccessTypeNode(I, StructPath, Offset, IsNewFormat, AccessType,
+                               BaseNode, MD) ||
+                SeenAccessTypeInPath,
+            "Did not see access type in access path!", &I, MD);
+  return true;
+}
 
-    if (isValidScalarTBAANode(BaseNode) || BaseNode == AccessType)
-      CheckTBAA(Offset == 0, "Offset not zero at the point of scalar access",
-                &I, MD, &Offset);
+bool TBAAVerifier::findAccessTypeNode(
+    Instruction &I, SmallPtrSetImpl<const MDNode *> &StructPath, APInt Offset,
+    bool IsNewFormat, const MDNode *AccessType, const MDNode *BaseNode,
+    const MDNode *MD) {
+  if (!BaseNode || IsRootTBAANode(BaseNode))
+    return false;
 
-    CheckTBAA(BaseNodeBitWidth == Offset.getBitWidth() ||
-                  (BaseNodeBitWidth == 0 && Offset == 0) ||
-                  (IsNewFormat && BaseNodeBitWidth == ~0u),
-              "Access bit-width not the same as description bit-width", &I, MD,
-              BaseNodeBitWidth, Offset.getBitWidth());
+  auto &&[Invalid, BaseNodeBitWidth] =
+      verifyTBAABaseNode(I, BaseNode, IsNewFormat);
 
-    if (IsNewFormat && SeenAccessTypeInPath)
-      break;
+  // If the base node is invalid in itself, then we've already printed all the
+  // errors we wanted to print.
+  if (Invalid)
+    return false;
+
+  // Offset at point of scalar access must be zero. Skip mismatched nodes.
+  if ((isValidScalarTBAANode(BaseNode) || BaseNode == AccessType) &&
+      Offset != 0)
+    return false;
+
+  CheckTBAA(BaseNodeBitWidth == Offset.getBitWidth() ||
+                (BaseNodeBitWidth == 0 && Offset == 0) ||
+                (IsNewFormat && BaseNodeBitWidth == ~0u),
+            "Access bit-width not the same as description bit-width", &I, MD,
+            BaseNodeBitWidth, Offset.getBitWidth());
+
+  bool SeenAccessTypeInPath = (BaseNode == AccessType && Offset == 0);
+
+  if (IsNewFormat && SeenAccessTypeInPath)
+    return true;
+
+  auto ProbableNodes =
+      getFieldNodeFromTBAABaseNode(I, BaseNode, Offset, IsNewFormat);
+
+  if (!StructPath.insert(BaseNode).second) {
+    CheckFailed("Cycle detected in struct path", &I, MD);
+    return false;
   }
 
-  CheckTBAA(SeenAccessTypeInPath, "Did not see access type in access path!", &I,
-            MD);
-  return true;
+  for (auto *PN : ProbableNodes) {
+    if (!PN || IsRootTBAANode(PN))
+      continue;
+
+    SmallPtrSet<const MDNode *, 4> StructPathCopy;
+    StructPathCopy.insert(StructPath.begin(), StructPath.end());
+
+    if (findAccessTypeNode(I, StructPathCopy, Offset, IsNewFormat, AccessType,
+                           PN, MD))
+      return true;
+  }
+
+  return SeenAccessTypeInPath;
 }
 
 char VerifierLegacyPass::ID = 0;
diff --git a/llvm/test/Analysis/TypeBasedAliasAnalysis/aggregates.ll b/llvm/test/Analysis/TypeBasedAliasAnalysis/aggregates.ll
index 4049c78049e03..422f8d8040468 100644
--- a/llvm/test/Analysis/TypeBasedAliasAnalysis/aggregates.ll
+++ b/llvm/test/Analysis/TypeBasedAliasAnalysis/aggregates.ll
@@ -105,6 +105,22 @@ entry:
   ret i32 %0
 }
 
+; C vs. D  =>  MayAlias.
+define i32 @f7(ptr %c, ptr %d) {
+entry:
+; CHECK-LABEL: f7
+; CHECK: MayAlias: store i16 7, {{.*}} <-> store i32 5,
+; OPT-LABEL: f7
+; OPT: store i32 5,
+; OPT: store i16 7,
+; OPT: load i32
+; OPT: ret i32
+  store i32 5, ptr %c, align 4, !tbaa !18  ; TAG_Union_int
+  store i16 7, ptr %d, align 4, !tbaa !17  ; TAG_Union_short
+  %0 = load i32, ptr %c, align 4, !tbaa !18  ; TAG_Union_int
+  ret i32 %0
+}
+
 !0 = !{!"root"}
 !1 = !{!0, i64 1, !"char"}
 !2 = !{!1, i64 4, !"int"}
@@ -128,3 +144,7 @@ entry:
 
 !14 = !{!4, i64 2, !"D", !11, i64 0, i64 2}
 !15 = !{!14, !14, i64 0, i64 2}  ; TAG_D
+
+!16 = !{!1, i64 2, !"Union", !11, i64 0, i64 2, !2, i64 0, i64 4}
+!17 = !{!16, !11, i64 0, i64 2}  ; TAG_Union_short
+!18 = !{!16, !2, i64 0, i64 4}  ; TAG_Union_int
diff --git a/llvm/test/Verifier/tbaa.ll b/llvm/test/Verifier/tbaa.ll
index abaa415aed749..107192542d55d 100644
--- a/llvm/test/Verifier/tbaa.ll
+++ b/llvm/test/Verifier/tbaa.ll
@@ -61,15 +61,15 @@ define void @f_1(ptr %ptr) {
 ; CHECK: Cycle detected in struct path
 ; CHECK-NEXT:  store i32 0, ptr %ptr, align 4, !tbaa !{{[0-9]+}}
 
-; CHECK: Offset not zero at the point of scalar access
+; CHECK: Did not see access type in access path
+; CHECK-NEXT:  store i32 0, ptr %ptr, align 4, !tbaa !{{[0-9]+}}
+
+; CHECK: Did not see access type in access path
 ; CHECK-NEXT:  store i32 1, ptr %ptr, align 4, !tbaa !{{[0-9]+}}
 
-; CHECK: Offset not zero at the point of scalar access
+; CHECK: Did not see access type in access path
 ; CHECK-NEXT:  store i32 2, ptr %ptr, align 4, !tbaa !{{[0-9]+}}
 
-; CHECK: Could not find TBAA parent in struct type node
-; CHECK-NEXT:  store i32 3, ptr %ptr, align 4, !tbaa !{{[0-9]+}}
-
 ; CHECK: Did not see access type in access path!
 ; CHECK-NEXT:  store i32 3, ptr %ptr, align 4, !tbaa !{{[0-9]+}}
 



More information about the llvm-commits mailing list