[clang] [llvm] [HLSL] Add `Increment`/`DecrementCounter` methods to structured buffers (PR #114148)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Thu Nov 14 00:43:40 PST 2024


================
@@ -343,27 +336,224 @@ struct TemplateParameterListBuilder {
     Params.clear();
 
     QualType T = Builder.Template->getInjectedClassNameSpecialization();
-    T = S.Context.getInjectedClassNameType(Builder.Record, T);
+    T = AST.getInjectedClassNameType(Builder.Record, T);
 
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("builtin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls, and when the first
+// method that builds the body is called it creates the CXXMethodDecl and the
+// ParmVarDecl instances. These can then be referenced from the body building
+// methods. The destructor or an explicit call to finalizeMethod() completes
+// the method definition. Finalizing an already-finalized method is a no-op,
+// so an explicit finalizeMethod() followed by the destructor is safe even if
+// the record definition has been completed in between.
+//
+// The builder is non-copyable because its destructor finalizes the method.
+struct BuiltinTypeMethodBuilder {
+  // A parameter collected by addParam() before the method decl is created.
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  // The destructor finalizes the method, so copies must not exist.
+  BuiltinTypeMethodBuilder(const BuiltinTypeMethodBuilder &) = delete;
+  BuiltinTypeMethodBuilder &
+  operator=(const BuiltinTypeMethodBuilder &) = delete;
+
+  // Records a parameter. Must be called before any body-building method,
+  // because the first of those creates the method declaration.
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II = DeclBuilder.S.getASTContext().Idents.get(
+        Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  // Creates the CXXMethodDecl and its ParmVarDecls from the collected
+  // parameters.
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = DeclBuilder.S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (const MethodParam &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // Create the params and attach them to the function prototype. Like the
+    // parameters Sema creates for builtins, they are parented to the method
+    // itself rather than to the enclosing record.
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    for (unsigned I = 0, E = Params.size(); I != E; ++I) {
+      const MethodParam &MP = Params[I];
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method, SourceLocation(), SourceLocation(), &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      // 'in' is the default modifier, so only non-default ones get an attr.
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(I, Parm);
+    }
+    Method->setParams(ParmDecls);
+  }
+
+  // Appends an implicit member access to the record's resource handle field,
+  // accessed through an implicit `this`, to Parms.
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = DeclBuilder.S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  // Appends a call to the named builtin to the method body. When
+  // AddResourceHandleAsFirstArg is true, the resource handle is passed
+  // before CallParms.
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    // The first statement added to a method creates the declaration.
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = DeclBuilder.S.getASTContext();
+    FunctionDecl *FD = lookupBuiltinFunction(DeclBuilder.S, BuiltinName);
+    DeclRefExpr *DRE = DeclRefExpr::Create(
+        AST, NestedNameSpecifierLoc(), SourceLocation(), FD, false,
+        FD->getNameInfo(), FD->getType(), VK_PRValue);
+
+    SmallVector<Expr *> Args;
+    if (AddResourceHandleAsFirstArg)
+      addResourceHandleToParms(Args);
+    Args.append(CallParms.begin(), CallParms.end());
+
+    Expr *Call =
+        CallExpr::Create(AST, DRE, Args, FD->getReturnType(), VK_PRValue,
+                         SourceLocation(), FPOptionsOverride());
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  // Completes the method and adds it to the record. For non-void methods,
+  // a trailing expression statement becomes the return statement. The body
+  // and attributes are set before the method is added. Calling this again
+  // after the method has a body (e.g. from the destructor after an explicit
+  // call) does nothing and returns the declaration builder.
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    assert(
+        Method != nullptr &&
+        "method decl not created; are you missing a call to build the body?");
+    if (Method->hasBody())
+      return DeclBuilder;
+
+    assert(!DeclBuilder.Record->isCompleteDefinition() &&
+           "record is already complete");
+
+    ASTContext &AST = DeclBuilder.S.getASTContext();
+    // isVoidType() looks through typedefs and qualifiers, unlike comparing
+    // against AST.VoidTy.
+    if (!ReturnTy->isVoidType() && !StmtsList.empty()) {
+      if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+        assert(AST.hasSameUnqualifiedType(
+                   isa<CallExpr>(LastExpr)
+                       ? cast<CallExpr>(LastExpr)->getCallReturnType(AST)
+                       : LastExpr->getType(),
+                   ReturnTy) &&
+               "Return type of the last statement must match the return type "
+               "of the method");
+        StmtsList.back() =
+            ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr);
+      }
+    }
+
+    Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                         SourceLocation(), SourceLocation()));
+    Method->setLexicalDeclContext(DeclBuilder.Record);
+    Method->setAccess(AccessSpecifier::AS_public);
+    Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+        AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+    DeclBuilder.Record->addDecl(Method);
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
-TemplateParameterListBuilder
-BuiltinTypeDeclBuilder::addTemplateArgumentList(Sema &S) {
-  return TemplateParameterListBuilder(S, *this);
+// Returns a TemplateParameterListBuilder bound to this record builder. Callers
+// such as addSimpleTemplateParams add parameters to it and then call
+// finalizeTemplateArgs(). The builder is returned as a prvalue temporary.
+TemplateParameterListBuilder BuiltinTypeDeclBuilder::addTemplateArgumentList() {
+  return TemplateParameterListBuilder(*this);
 }
 
 BuiltinTypeDeclBuilder &
-BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
-                                                ArrayRef<StringRef> Names) {
-  TemplateParameterListBuilder Builder = this->addTemplateArgumentList(S);
+BuiltinTypeDeclBuilder::addSimpleTemplateParams(ArrayRef<StringRef> Names) {
+  // Adds one type template parameter per entry in Names. If the record
+  // definition is already complete, nothing is added. In that case the code
+  // only asserts that the record is a template with exactly Names.size()
+  // parameters.
+  if (Record->isCompleteDefinition()) {
+    assert(Template && "existing record is not a template");
+    assert(Template->getTemplateParameters()->size() == Names.size() &&
+           "template param count mismatch");
+    return *this;
+  }
+
+  TemplateParameterListBuilder Builder = addTemplateArgumentList();
   for (StringRef Name : Names)
     Builder.addTypeParameter(Name);
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addIncrementCounterMethod() {
+  ASTContext &AST = S.getASTContext();
+  Expr *One = IntegerLiteral::Create(
+      AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1, true), AST.IntTy,
+      SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
----------------
hekota wrote:

See my comments above.

https://github.com/llvm/llvm-project/pull/114148


More information about the cfe-commits mailing list