[clang] [llvm] [HLSL] Add `Increment`/`DecrementCounter` methods to structured buffers (PR #114148)

Chris B via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 31 13:56:01 PDT 2024


================
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("builtin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls, and when the first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecl instances. These can then be referenced from the body-building
+// methods. The destructor or an explicit call to finalizeMethod() will
+// complete the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the builtin with all of the method parameters
+    // plus the optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
----------------
llvm-beanz wrote:

Do we need to handle the case where `LastExpr->getType() != ReturnTy`?
If not, maybe we should add an assert, something like:
```
assert(AST.hasSameUnqualifiedType(LastExpr->getType(), ReturnTy) && "Should be the same type!");
```

https://github.com/llvm/llvm-project/pull/114148


More information about the cfe-commits mailing list