[llvm-branch-commits] [clang] [HLSL] Add RWBuffer::Load(Index) (PR #117018)

Justin Bogner via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 20 09:58:21 PST 2024


https://github.com/bogner created https://github.com/llvm/llvm-project/pull/117018

This method is the same as `operator[]`, except that it returns a value instead of a reference.

From 8c9d382ae6d3c6e2dc8a0d738a97f3b3d14d5413 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Tue, 19 Nov 2024 14:32:44 -0800
Subject: [PATCH] [HLSL] Add RWBuffer::Load(Index)

This method is the same as `operator[]`, except that it returns a value instead
of a reference.
---
 clang/lib/Sema/HLSLExternalSemaSource.cpp     | 35 ++++++++++++++-----
 clang/test/AST/HLSL/RWBuffer-AST.hlsl         | 15 ++++++++
 .../builtins/RWBuffer-subscript.hlsl          |  7 ++++
 3 files changed, 48 insertions(+), 9 deletions(-)

diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 4df2893473d474..ae026e369ea78e 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -189,12 +189,29 @@ struct BuiltinTypeDeclBuilder {
   BuiltinTypeDeclBuilder &addArraySubscriptOperators(Sema &S) {
     if (Record->isCompleteDefinition())
       return *this;
-    addArraySubscriptOperator(S, true);
-    addArraySubscriptOperator(S, false);
+    // Declare const and non-const operator[](unsigned) returning references.
+    ASTContext &AST = Record->getASTContext();
+    DeclarationName Subscript =
+        AST.DeclarationNames.getCXXOperatorName(OO_Subscript);
+    addHandleAccessFunction(S, Subscript, /*IsConst=*/true, /*IsRef=*/true);
+    addHandleAccessFunction(S, Subscript, /*IsConst=*/false, /*IsRef=*/true);
+    return *this;
+  }
+
+  BuiltinTypeDeclBuilder &addLoadMethods(Sema &S) {
+    if (Record->isCompleteDefinition())
+      return *this;
+    ASTContext &AST = Record->getASTContext();
+    IdentifierInfo &II = AST.Idents.get("Load", tok::TokenKind::identifier);
+    DeclarationName Load(&II);
+    // Load(Idx): like operator[] but returns by value; non-const only.
+    addHandleAccessFunction(S, Load, /*IsConst=*/false, /*IsRef=*/false);
+
     return *this;
   }
 
-  BuiltinTypeDeclBuilder &addArraySubscriptOperator(Sema &S, bool IsConst) {
+  BuiltinTypeDeclBuilder &addHandleAccessFunction(Sema &S, DeclarationName Name,
+                                                  bool IsConst, bool IsRef) {
     if (Record->isCompleteDefinition())
       return *this;
 
@@ -216,18 +233,16 @@ struct BuiltinTypeDeclBuilder {
       ExtInfo.TypeQuals.addConst();
       ReturnTy.addConst();
     }
-    ReturnTy = AST.getLValueReferenceType(ReturnTy);
+    if (IsRef)
+      ReturnTy = AST.getLValueReferenceType(ReturnTy);
 
     QualType MethodTy =
         AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
     auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
     auto *MethodDecl = CXXMethodDecl::Create(
         AST, Record, SourceLocation(),
-        DeclarationNameInfo(
-            AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
-            SourceLocation()),
-        MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
-        SourceLocation());
+        DeclarationNameInfo(Name, SourceLocation()), MethodTy, TSInfo, SC_None,
+        false, false, ConstexprSpecKind::Unspecified, SourceLocation());
 
     IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
     auto *IdxParam = ParmVarDecl::Create(
@@ -489,6 +504,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/false)
         .addArraySubscriptOperators(*SemaPtr)
+        .addLoadMethods(*SemaPtr)
         .completeDefinition();
   });
 
@@ -501,6 +517,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/true,
                     /*RawBuffer=*/false)
         .addArraySubscriptOperators(*SemaPtr)
+        .addLoadMethods(*SemaPtr)
         .completeDefinition();
   });
 
diff --git a/clang/test/AST/HLSL/RWBuffer-AST.hlsl b/clang/test/AST/HLSL/RWBuffer-AST.hlsl
index 6a207ba3d8a7d2..17043046b96cac 100644
--- a/clang/test/AST/HLSL/RWBuffer-AST.hlsl
+++ b/clang/test/AST/HLSL/RWBuffer-AST.hlsl
@@ -64,6 +64,21 @@ RWBuffer<float> Buffer;
 // CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int'
 // CHECK-NEXT: AlwaysInlineAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit always_inline
 
+// CHECK-NEXT: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Load 'element_type (unsigned int)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Idx 'unsigned int'
+// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
+// CHECK-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
+// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
+// CHECK-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
+// CHECK-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]
+// CHECK-SAME: ' lvalue .__handle 0x{{[0-9A-Fa-f]+}}
+// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'RWBuffer<element_type>' lvalue implicit this
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int'
+// CHECK-NEXT: AlwaysInlineAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit always_inline
+
 // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class RWBuffer definition
 
 // CHECK: TemplateArgument type 'float'
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-subscript.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-subscript.hlsl
index 8ce8417772530c..4428b77dd9ec8e 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-subscript.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-subscript.hlsl
@@ -6,9 +6,16 @@ RWBuffer<int> Out;
 [numthreads(1,1,1)]
 void main(unsigned GI : SV_GroupIndex) {
   // CHECK: define void @main()
+
   // CHECK: %[[INPTR:.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
   // CHECK: %[[LOAD:.*]] = load i32, ptr %[[INPTR]]
   // CHECK: %[[OUTPTR:.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
   // CHECK: store i32 %[[LOAD]], ptr %[[OUTPTR]]
   Out[GI] = In[GI];
+  // Load(GI) returns by value; its getpointer call has no nonnull/deref attrs.
+  // CHECK: %[[INPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
+  // CHECK: %[[LOAD:.*]] = load i32, ptr %[[INPTR]]
+  // CHECK: %[[OUTPTR:.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.TypedBuffer_i32_1_0_1t(target("dx.TypedBuffer", i32, 1, 0, 1) %{{.*}}, i32 %{{.*}})
+  // CHECK: store i32 %[[LOAD]], ptr %[[OUTPTR]]
+  Out[GI] = In.Load(GI);
 }



More information about the llvm-branch-commits mailing list