[clang] [Clang] Implement the 'counted_by' attribute (PR #76348)

Yeoul Na via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 8 15:13:16 PST 2024


================
@@ -944,22 +951,259 @@ static llvm::Value *getArrayIndexingBound(CodeGenFunction &CGF,
   return nullptr;
 }
 
+namespace {
+
+/// \p StructAccessBase walks a field-access expression and returns the base
+/// \p Expr through which the struct of interest is reached. It returns
+/// either a \p DeclRefExpr, representing the base pointer to the struct, i.e.:
+///
+///     p in p->a.b.c
+///
+/// or a \p MemberExpr, if the \p MemberExpr has the \p RecordDecl we're
+/// looking for:
+///
+///     struct s {
+///       struct s *ptr;
+///       int count;
+///       char array[] __attribute__((counted_by(count)));
+///     };
+///
+/// If we have an expression like \p p->ptr->array[index], we want the
+/// \p MemberExpr for \p p->ptr instead of \p p.
+///
+/// Any expression kind without a dedicated visitor below falls through to
+/// \p VisitStmt and yields \p nullptr.
+class StructAccessBase
+    : public ConstStmtVisitor<StructAccessBase, const Expr *> {
+  /// The record we are trying to find an access base for.
+  const RecordDecl *ExpectedRD;
+
+  /// Returns true if the type of \p E, after looking through one level of
+  /// pointer, is the record \p ExpectedRD.
+  bool IsExpectedRecordDecl(const Expr *E) const {
+    QualType Ty = E->getType();
+    if (Ty->isPointerType())
+      Ty = Ty->getPointeeType();
+    return ExpectedRD == Ty->getAsRecordDecl();
+  }
+
+public:
+  /// \param ExpectedRD the record whose access base should be located.
+  /// The constructor is explicit so a \p RecordDecl pointer is never silently
+  /// converted into a visitor.
+  explicit StructAccessBase(const RecordDecl *ExpectedRD)
+      : ExpectedRD(ExpectedRD) {}
+
+  //===--------------------------------------------------------------------===//
+  //                            Visitor Methods
+  //===--------------------------------------------------------------------===//
+
+  // NOTE: If we build C++ support for counted_by, then we'll have to handle
+  // horrors like this:
+  //
+  //     struct S {
+  //       int x, y;
+  //       int blah[] __attribute__((counted_by(x)));
+  //     } s;
+  //
+  //     int foo(int index, int val) {
+  //       int (S::*IHatePMDs)[] = &S::blah;
+  //       (s.*IHatePMDs)[index] = val;
+  //     }
+
+  /// Entry point: dispatches \p E to the matching visitor method.
+  const Expr *Visit(const Expr *E) {
+    return ConstStmtVisitor<StructAccessBase, const Expr *>::Visit(E);
+  }
+
+  /// Fallback for every statement kind not handled below: no base found.
+  const Expr *VisitStmt(const Stmt *) { return nullptr; }
+
+  // These are the types we expect to return (in order of most to least
+  // likely):
+  //
+  //   1. DeclRefExpr - This is the expression for the base of the structure.
+  //      It's exactly what we want to build an access to the \p counted_by
+  //      field.
+  //   2. MemberExpr - This is the expression that has the same \p RecordDecl
+  //      as the flexible array member's lexical enclosing \p RecordDecl. This
+  //      allows us to catch things like: "p->p->array"
+  //   3. CompoundLiteralExpr - This is for people who create something
+  //      heretical like (struct foo has a flexible array member):
+  //
+  //        (struct foo){ 1, 2 }.blah[idx];
+  //
+  // A CallExpr or ArraySubscriptExpr whose type matches the expected record
+  // is returned as well.
+  const Expr *VisitDeclRefExpr(const DeclRefExpr *E) {
+    return IsExpectedRecordDecl(E) ? E : nullptr;
+  }
+  const Expr *VisitMemberExpr(const MemberExpr *E) {
+    // An arrow access whose type is the expected record is the closest base;
+    // stop here rather than walking further up the chain.
+    if (IsExpectedRecordDecl(E) && E->isArrow())
+      return E;
+    // Otherwise prefer whatever the base expression yields and only fall back
+    // to this member if nothing was found there.
+    const Expr *Res = Visit(E->getBase());
+    return !Res && IsExpectedRecordDecl(E) ? E : Res;
+  }
+  const Expr *VisitCompoundLiteralExpr(const CompoundLiteralExpr *E) {
+    return IsExpectedRecordDecl(E) ? E : nullptr;
+  }
+  const Expr *VisitCallExpr(const CallExpr *E) {
+    return IsExpectedRecordDecl(E) ? E : nullptr;
+  }
+
+  const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
+    if (IsExpectedRecordDecl(E))
+      return E;
+    return Visit(E->getBase());
+  }
+
+  // The remaining visitors are transparent: they look through the wrapper and
+  // continue with the sub-expression.
+  const Expr *VisitCastExpr(const CastExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitParenExpr(const ParenExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryAddrOf(const UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryDeref(const UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+};
+
+} // end anonymous namespace
+
+/// A sequence of (record, LLVM field index) pairs, one per GEP step.
+using RecIndicesTy =
+    SmallVector<std::pair<const RecordDecl *, llvm::Value *>, 8>;
+
+/// Search \p RD, and any records declared inside it, for the field \p FD.
+///
+/// On success, one (record, index) entry per nesting level is appended to
+/// \p Indices and true is returned. Entries are appended innermost-first, so
+/// a caller walking from \p RD down to \p FD must traverse \p Indices in
+/// reverse. On failure, false is returned and nothing is appended.
+static bool getGEPIndicesToField(CodeGenFunction &CGF, const RecordDecl *RD,
+                                 const FieldDecl *FD, RecIndicesTy &Indices) {
+  const CGRecordLayout &RecLayout = CGF.CGM.getTypes().getCGRecordLayout(RD);
+
+  // Index of the most recently seen member; -1 until the first one is seen.
+  int64_t CurIdx = -1;
+
+  for (const Decl *Member : RD->decls()) {
+    if (const auto *MemberField = dyn_cast<FieldDecl>(Member)) {
+      CurIdx = RecLayout.getLLVMFieldNo(MemberField);
+      if (MemberField != FD)
+        continue;
+
+      Indices.emplace_back(RD, CGF.Builder.getInt32(CurIdx));
+      return true;
+    }
+
+    const auto *Nested = dyn_cast<RecordDecl>(Member);
+    if (!Nested)
+      continue;
+
+    // A nested record declaration advances the running index by one before
+    // we descend into it.
+    ++CurIdx;
+    if (!getGEPIndicesToField(CGF, Nested, FD, Indices))
+      continue;
+
+    Indices.emplace_back(RD, CGF.Builder.getInt32(CurIdx));
+    return true;
+  }
+
+  return false;
+}
+
+/// Emit a load of the \p counted_by field \p CountDecl of the struct that is
+/// accessed through \p Base.
+///
+/// \param Base      the expression from which the struct base is derived.
+/// \param FAMDecl   the field looked up in \p LocalDeclMap when the struct
+///                  base is not a \p DeclRefExpr.
+/// \param CountDecl the field whose value is loaded.
+/// \returns the loaded value, or \p nullptr if no suitable struct base could
+///          be found, no address could be formed for it, or \p CountDecl
+///          could not be located inside its outermost enclosing record.
+llvm::Value *CodeGenFunction::EmitCountedByFieldExpr(
+    const Expr *Base, const FieldDecl *FAMDecl, const FieldDecl *CountDecl) {
+  // This method is typically called in contexts where we can't generate
+  // side-effects, like in __builtin_dynamic_object_size. When finding
+  // expressions, only choose those that have either already been emitted or
+  // can be loaded without side-effects.
+  const RecordDecl *RD = CountDecl->getParent()->getOuterLexicalRecordContext();
+
+  // Find the base struct expr (i.e. p in p->a.b.c.d).
+  const Expr *StructBase = StructAccessBase(RD).Visit(Base);
+  if (!StructBase)
+    return nullptr;
+
+  // Resolve the path to the count field before emitting any instructions. If
+  // the field cannot be found there is nothing meaningful to load, and
+  // bailing out here avoids leaving partially emitted code behind.
+  RecIndicesTy Indices;
+  if (!getGEPIndicesToField(*this, RD, CountDecl, Indices))
+    return nullptr;
+
+  llvm::Value *Res = nullptr;
+  if (const auto *DRE = dyn_cast<DeclRefExpr>(StructBase)) {
+    // Load the value of the referenced declaration from its storage.
+    Res = EmitDeclRefLValue(DRE).getPointer(*this);
+
+    const Type *Ty = DRE->getType().getTypePtr();
+    Res = Builder.CreateAlignedLoad(ConvertType(QualType(Ty, 0)), Res,
+                                    getPointerAlign(), "dre.load");
+  } else {
+    // Prefer an address that has already been emitted for FAMDecl.
+    Address Addr = Address::invalid();
+    auto I = LocalDeclMap.find(FAMDecl);
+    if (I != LocalDeclMap.end()) {
+      Res = I->second.getPointer();
+      Addr = Address(Res, ConvertType(FAMDecl->getType()), getPointerAlign());
+    }
+
+    // Otherwise emit the base expression, but only if doing so has no
+    // side-effects.
+    bool NeedLoad = true;
+    if (!Res && !StructBase->HasSideEffects(getContext())) {
+      LValueBaseInfo EltBaseInfo;
+      TBAAAccessInfo EltTBAAInfo;
+      Addr = EmitPointerWithAlignment(StructBase, &EltBaseInfo, &EltTBAAInfo);
+      NeedLoad = false;
+    }
+
+    if (!Addr.isValid())
+      return nullptr;
+
+    // NOTE(review): the result of this load is overwritten unconditionally by
+    // the assignment right below, so the load is emitted but never used. The
+    // existing behavior is kept as-is; confirm whether an `else` was intended.
+    if (NeedLoad)
+      Res = Builder.CreateLoad(Addr, /*isVolatile=*/false, "struct.load");
+
+    Res = Addr.getPointer();
+  }
+
+  if (!Res)
+    return nullptr;
+
+  llvm::Value *Zero = Builder.getInt32(0);
+
+  // Indices were recorded innermost-first; walk them in reverse to step from
+  // the outermost record down to the count field.
+  for (auto I = Indices.rbegin(), E = Indices.rend(); I != E; ++I)
+    Res = Builder.CreateInBoundsGEP(
+        ConvertType(QualType(I->first->getTypeForDecl(), 0)), Res,
+        {Zero, I->second}, "..counted_by.gep");
+
+  return Builder.CreateAlignedLoad(ConvertType(CountDecl->getType()), Res,
+                                   getIntAlign(), "..counted_by.load");
+}
+
+const FieldDecl *CodeGenFunction::FindCountedByField(const FieldDecl *FD) {
+  if (!FD || !FD->hasAttr<CountedByAttr>())
+    return nullptr;
+
+  const auto *CBA = FD->getAttr<CountedByAttr>();
+  if (!CBA)
+    return nullptr;
+
+  auto GetNonAnonStructOrUnion = [](const RecordDecl *RD) {
+    while (RD && !RD->getDeclName())
+      if (const auto *R = dyn_cast<RecordDecl>(RD->getDeclContext()))
+        RD = R;
+      else
+        break;
----------------
rapidsna wrote:

I see how it can potentially hit the condition based on your example above. My question was actually more about whether returning `RD` is the right behavior in the `if (!R)` case, because the functor will then still return an anonymous struct, while the name `GetNonAnonStructOrUnion` suggests that it returns a non-anonymous struct. But given your comment, I guess that's what you intended, in order to cover cases like the one below, right?

```
int foo(void) { 
  struct { 
    int count;
    int array[] __counted_by(count); 
  } x = { 3 }; 
  ... 
}
```

If so, could you add a comment that this functor may return an anonymous struct in this special case?

https://github.com/llvm/llvm-project/pull/76348


More information about the cfe-commits mailing list