[clang] [HLSL] Implement `Append` and `Consume` methods on `Append`/`ConsumeStructuredBuffer` (PR #118536)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 3 11:35:39 PST 2024


https://github.com/hekota created https://github.com/llvm/llvm-project/pull/118536

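The methods use the existing Clang builtins `__builtin_hlsl_buffer_update_counter` and `__builtin_hlsl_resource_getpointer`: each method first updates the buffer counter, then passes the returned index to `__builtin_hlsl_resource_getpointer` and stores (`Append`) or loads (`Consume`) the value through the resulting pointer.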

Fixes #112968

From 97391c97c16029e2a3376ff2938ebc7e94393149 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Tue, 3 Dec 2024 11:13:47 -0800
Subject: [PATCH] [HLSL] Add Append and Consume methods on
 Append/ConsumeStructuredBuffer

---
 clang/lib/Sema/HLSLExternalSemaSource.cpp     | 58 ++++++++++++++++++-
 .../test/AST/HLSL/StructuredBuffers-AST.hlsl  | 46 ++++++++++++++-
 .../StructuredBuffers-methods-lib.hlsl        | 40 ++++++++++---
 3 files changed, 133 insertions(+), 11 deletions(-)

diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index f849e841de190a..882e0a8e89ba5a 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -231,6 +231,8 @@ struct BuiltinTypeDeclBuilder {
   BuiltinTypeDeclBuilder &addDecrementCounterMethod();
   BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name,
                                                   bool IsConst, bool IsRef);
+  BuiltinTypeDeclBuilder &addAppendMethod();
+  BuiltinTypeDeclBuilder &addConsumeMethod();
 };
 
 struct TemplateParameterListBuilder {
@@ -430,12 +432,20 @@ struct BuiltinTypeMethodBuilder {
   // Argument placeholders, inspired by std::placeholder. These are the indices
   // of arguments to forward to `callBuiltin`, and additionally `Handle` which
   // refers to the resource handle.
-  enum class PlaceHolder { _0, _1, _2, _3, Handle = 127 };
+  enum class PlaceHolder { _0, _1, _2, _3, Handle = 127, LastStmtPop = 128 };
 
   Expr *convertPlaceholder(PlaceHolder PH) {
     if (PH == PlaceHolder::Handle)
       return getResourceHandleExpr();
 
+    if (PH == PlaceHolder::LastStmtPop) {
+      assert(!StmtsList.empty() && "no statements in the list");
+      Stmt *LastStmt = StmtsList.pop_back_val();
+      assert(isa<ValueStmt>(LastStmt) &&
+             "last statement does not have a value");
+      return cast<ValueStmt>(LastStmt)->getExprStmt();
+    }
+
     ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
     ParmVarDecl *ParamDecl = Method->getParamDecl(static_cast<unsigned>(PH));
     return DeclRefExpr::Create(
@@ -558,6 +568,22 @@ struct BuiltinTypeMethodBuilder {
     return *this;
   }
 
+  template <typename T> BuiltinTypeMethodBuilder &assign(T RHS) {
+    Expr *RHSExpr = convertPlaceholder(RHS);
+
+    assert(!StmtsList.empty() && "Nothing to assign to");
+    Expr *LastExpr = dyn_cast<Expr>(StmtsList.back());
+    assert(LastExpr && "No expression to assign");
+
+    Stmt *AssignStmt = BinaryOperator::Create(
+        DeclBuilder.SemaRef.getASTContext(), LastExpr, RHSExpr, BO_Assign,
+        LastExpr->getType(), ExprValueKind::VK_PRValue,
+        ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride());
+    StmtsList.pop_back();
+    StmtsList.push_back(AssignStmt);
+    return *this;
+  }
+
   BuiltinTypeMethodBuilder &dereference() {
     assert(!StmtsList.empty() && "Nothing to dereference");
     ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
@@ -674,6 +700,34 @@ BuiltinTypeDeclBuilder::addHandleAccessFunction(DeclarationName &Name,
       .finalizeMethod();
 }
 
+BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
+  using PH = BuiltinTypeMethodBuilder::PlaceHolder;
+  ASTContext &AST = SemaRef.getASTContext();
+  QualType ElemTy = getHandleElementType();
+  return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy)
+      .addParam("value", ElemTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
+                   PH::Handle, getConstantIntExpr(1))
+      .callBuiltin("__builtin_hlsl_resource_getpointer",
+                   AST.getPointerType(ElemTy), PH::Handle, PH::LastStmtPop)
+      .dereference()
+      .assign(PH::_0)
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() {
+  using PH = BuiltinTypeMethodBuilder::PlaceHolder;
+  ASTContext &AST = SemaRef.getASTContext();
+  QualType ElemTy = getHandleElementType();
+  return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
+                   PH::Handle, getConstantIntExpr(1))
+      .callBuiltin("__builtin_hlsl_resource_getpointer",
+                   AST.getPointerType(ElemTy), PH::Handle, PH::LastStmtPop)
+      .dereference()
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -900,6 +954,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer,
                     /*IsROV=*/false, /*RawBuffer=*/true)
+        .addAppendMethod()
         .completeDefinition();
   });
 
@@ -910,6 +965,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer,
                     /*IsROV=*/false, /*RawBuffer=*/true)
+        .addConsumeMethod()
         .completeDefinition();
   });
 
diff --git a/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl b/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl
index afee0e120afdb1..5280157aa42d1f 100644
--- a/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl
+++ b/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl
@@ -20,7 +20,7 @@
 //
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \
 // RUN:   -DRESOURCE=AppendStructuredBuffer %s | FileCheck -DRESOURCE=AppendStructuredBuffer \
-// RUN:   -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s
+// RUN:   -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-APPEND %s
 //
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \
 // RUN:  -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \
@@ -28,7 +28,7 @@
 //
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \
 // RUN:   -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \
-// RUN:   -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s
+// RUN:   -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-CONSUME %s
 //
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \
 // RUN:  -DRESOURCE=RasterizerOrderedStructuredBuffer %s | FileCheck -DRESOURCE=RasterizerOrderedStructuredBuffer \
@@ -135,6 +135,48 @@ RESOURCE<float> Buffer;
 // CHECK-COUNTER-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1
 // CHECK-COUNTER-NEXT: AlwaysInlineAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit always_inline
 
+// CHECK-APPEND: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Append 'void (element_type)'
+// CHECK-APPEND-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> value 'element_type'
+// CHECK-APPEND-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-APPEND-NEXT: BinaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' '='
+// CHECK-APPEND-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
+// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
+// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
+// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
+// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
+// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]]
+// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
+// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
+// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int'
+// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) noexcept'
+// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
+// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
+// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]]
+// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
+// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
+// CHECK-APPEND-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' 1
+// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' ParmVar 0x{{[0-9A-Fa-f]+}} 'value' 'element_type'
+
+// CHECK-CONSUME: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Consume 'element_type ()'
+// CHECK-CONSUME-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-CONSUME-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-CONSUME-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
+// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
+// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
+// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
+// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
+// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]]
+// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
+// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
+// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int'
+// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) noexcept'
+// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
+// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
+// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]]
+// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
+// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
+// CHECK-CONSUME-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' 1
+
 // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class [[RESOURCE]] definition
 
 // CHECK: TemplateArgument type 'float'
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
index b7986ae7dda1c2..bc65db503f8618 100644
--- a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -5,21 +5,45 @@
 
 RWStructuredBuffer<float> RWSB1 : register(u0);
 RWStructuredBuffer<float> RWSB2 : register(u1);
+AppendStructuredBuffer<float> ASB : register(u2);
+ConsumeStructuredBuffer<float> CSB : register(u3);
 
 // CHECK: %"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", float, 1, 0) }
 
-export void TestIncrementCounter() {
-    RWSB1.IncrementCounter();
+export int TestIncrementCounter() {
+    return RWSB1.IncrementCounter();
 }
 
-// CHECK: define void @_Z20TestIncrementCounterv()
-// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
+// CHECK: define noundef i32 @_Z20TestIncrementCounterv()
+// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
+// CHECK-DXIL: ret i32 %[[INDEX]]
+export int TestDecrementCounter() {
+    return RWSB2.DecrementCounter();
+}
+
+// CHECK: define noundef i32 @_Z20TestDecrementCounterv()
+// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1)
+// CHECK-DXIL: ret i32 %[[INDEX]]
+
+export void TestAppend(float value) {
+    ASB.Append(value);
+}
+
+// CHECK: define void @_Z10TestAppendf(float noundef %value)
+// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %value.addr, align 4
+// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
+// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i32 %[[INDEX]])
+// CHECK-DXIL: store float %[[VALUE]], ptr %[[RESPTR]], align 4
 
-export void TestDecrementCounter() {
-    RWSB2.DecrementCounter();
+export float TestConsume() {
+    return CSB.Consume();
 }
 
-// CHECK: define void @_Z20TestDecrementCounterv()
-// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1)
+// CHECK: define noundef float @_Z11TestConsumev()
+// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %1, i8 1)
+// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %0, i32 %[[INDEX]])
+// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %[[RESPTR]], align 4
+// CHECK-DXIL: ret float %[[VALUE]]
 
 // CHECK: declare i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i8)
+// CHECK: declare ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i32)



More information about the cfe-commits mailing list