[clang] 00ecacc - [HLSL] Generate buffer subscript operators

Chris Bieneman via cfe-commits cfe-commits at lists.llvm.org
Fri Sep 2 12:55:58 PDT 2022


Author: Chris Bieneman
Date: 2022-09-02T14:55:43-05:00
New Revision: 00ecacca7d90f96a1d54bc3fa38986fdd64e4c72

URL: https://github.com/llvm/llvm-project/commit/00ecacca7d90f96a1d54bc3fa38986fdd64e4c72
DIFF: https://github.com/llvm/llvm-project/commit/00ecacca7d90f96a1d54bc3fa38986fdd64e4c72.diff

LOG: [HLSL] Generate buffer subscript operators

In HLSL, buffer types support array subscripting syntax for loads and
stores. This change fleshes out the subscript operators to become array
accesses on the underlying handle pointer. This will allow LLVM
optimization passes to optimize resource accesses the same way any other
memory access would be optimized.

Reviewed By: aaron.ballman

Differential Revision: https://reviews.llvm.org/D131268

Added: 
    clang/test/CodeGenHLSL/buffer-array-operator.hlsl

Modified: 
    clang/lib/Sema/HLSLExternalSemaSource.cpp
    clang/lib/Sema/SemaType.cpp
    clang/test/AST/HLSL/RWBuffer-AST.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index fe963fdbf2781..ee3aa4d42a049 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -104,7 +104,14 @@ struct BuiltinTypeDeclBuilder {
 
   BuiltinTypeDeclBuilder &
   addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
-    return addMemberVariable("h", Record->getASTContext().VoidPtrTy, Access);
+    QualType Ty = Record->getASTContext().VoidPtrTy;
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0)))
+        Ty = Record->getASTContext().getPointerType(
+            QualType(TTD->getTypeForDecl(), 0));
+    }
+    return addMemberVariable("h", Ty, Access);
   }
 
   BuiltinTypeDeclBuilder &
@@ -158,15 +165,25 @@ struct BuiltinTypeDeclBuilder {
         lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
 
     Expr *RCExpr = emitResourceClassExpr(AST, RC);
-    CallExpr *Call =
-        CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
-                         SourceLocation(), FPOptionsOverride());
+    Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
+                                  SourceLocation(), FPOptionsOverride());
 
     CXXThisExpr *This = new (AST)
         CXXThisExpr(SourceLocation(), Constructor->getThisType(), true);
-    MemberExpr *Handle = MemberExpr::CreateImplicit(
-        AST, This, true, Fields["h"], Fields["h"]->getType(), VK_LValue,
-        OK_Ordinary);
+    Expr *Handle = MemberExpr::CreateImplicit(AST, This, true, Fields["h"],
+                                              Fields["h"]->getType(), VK_LValue,
+                                              OK_Ordinary);
+
+    // If the handle isn't a void pointer, cast the builtin result to the
+    // correct type.
+    if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
+      Call = CXXStaticCastExpr::Create(
+          AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
+          AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
+          FPOptionsOverride(), SourceLocation(), SourceLocation(),
+          SourceRange());
+    }
+
     BinaryOperator *Assign = BinaryOperator::Create(
         AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
         SourceLocation(), FPOptionsOverride());
@@ -179,6 +196,85 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
+    addArraySubscriptOperator(true);
+    addArraySubscriptOperator(false);
+    return *this;
+  }
+
+  BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
+    assert(Fields.count("h") > 0 &&
+           "Subscript operator must be added after the handle.");
+
+    FieldDecl *Handle = Fields["h"];
+    ASTContext &AST = Record->getASTContext();
+
+    assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
+           "Not yet supported for void pointer handles.");
+
+    QualType ElemTy =
+        QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
+    QualType ReturnTy = ElemTy;
+
+    FunctionProtoType::ExtProtoInfo ExtInfo;
+
+    // Subscript operators return references to elements, const makes the
+    // reference and method const so that the underlying data is not mutable.
+    ReturnTy = AST.getLValueReferenceType(ReturnTy);
+    if (IsConst) {
+      ExtInfo.TypeQuals.addConst();
+      ReturnTy.addConst();
+    }
+
+    QualType MethodTy =
+        AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    auto *MethodDecl = CXXMethodDecl::Create(
+        AST, Record, SourceLocation(),
+        DeclarationNameInfo(
+            AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
+            SourceLocation()),
+        MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
+        SourceLocation());
+
+    IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
+    auto *IdxParam = ParmVarDecl::Create(
+        AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
+        &II, AST.UnsignedIntTy,
+        AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
+        SC_None, nullptr);
+    MethodDecl->setParams({IdxParam});
+
+    // Also add the parameter to the function prototype.
+    auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    FnProtoLoc.setParam(0, IdxParam);
+
+    auto *This = new (AST)
+        CXXThisExpr(SourceLocation(), MethodDecl->getThisType(), true);
+    auto *HandleAccess = MemberExpr::CreateImplicit(
+        AST, This, true, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
+
+    auto *IndexExpr = DeclRefExpr::Create(
+        AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
+        DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
+        AST.UnsignedIntTy, VK_PRValue);
+
+    auto *Array =
+        new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
+                                     OK_Ordinary, SourceLocation());
+
+    auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
+
+    MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
+                                             SourceLocation(),
+                                             SourceLocation()));
+    MethodDecl->setLexicalDeclContext(Record);
+    MethodDecl->setAccess(AccessSpecifier::AS_public);
+    Record->addDecl(MethodDecl);
+
+    return *this;
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     Record->startDefinition();
     return *this;
@@ -368,6 +464,7 @@ void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) {
   BuiltinTypeDeclBuilder(Record)
       .addHandleMember()
       .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
+      .addArraySubscriptOperators()
       .annotateResourceClass(HLSLResourceAttr::UAV)
       .completeDefinition();
 }

diff  --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 313a534a1246e..e87b59d819988 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2174,7 +2174,7 @@ QualType Sema::BuildPointerType(QualType T,
     return QualType();
   }
 
-  if (getLangOpts().HLSL) {
+  if (getLangOpts().HLSL && Loc.isValid()) {
     Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0;
     return QualType();
   }
@@ -2244,7 +2244,7 @@ QualType Sema::BuildReferenceType(QualType T, bool SpelledAsLValue,
     return QualType();
   }
 
-  if (getLangOpts().HLSL) {
+  if (getLangOpts().HLSL && Loc.isValid()) {
     Diag(Loc, diag::err_hlsl_pointers_unsupported) << 1;
     return QualType();
   }
@@ -3008,7 +3008,7 @@ QualType Sema::BuildMemberPointerType(QualType T, QualType Class,
     return QualType();
   }
 
-  if (getLangOpts().HLSL) {
+  if (getLangOpts().HLSL && Loc.isValid()) {
     Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0;
     return QualType();
   }

diff  --git a/clang/test/AST/HLSL/RWBuffer-AST.hlsl b/clang/test/AST/HLSL/RWBuffer-AST.hlsl
index c9cbd730933fe..193ef67e152b7 100644
--- a/clang/test/AST/HLSL/RWBuffer-AST.hlsl
+++ b/clang/test/AST/HLSL/RWBuffer-AST.hlsl
@@ -39,11 +39,30 @@ RWBuffer<float> Buffer;
 
 // CHECK: FinalAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit final
 // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit UAV
-// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> implicit h 'void *'
+// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> implicit h 'element_type *'
+
+// CHECK: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> operator[] 'element_type &const (unsigned int) const'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Idx 'unsigned int'
+// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' lvalue
+// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}}
+// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'const RWBuffer<element_type> *' implicit this
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int'
+
+// CHECK-NEXT: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> operator[] 'element_type &(unsigned int)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Idx 'unsigned int'
+// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' lvalue
+// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}}
+// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'RWBuffer<element_type> *' implicit this
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int'
+
 // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class RWBuffer definition
 
 // CHECK: TemplateArgument type 'float'
 // CHECK-NEXT: BuiltinType 0x{{[0-9A-Fa-f]+}} 'float'
 // CHECK-NEXT: FinalAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit final
 // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit UAV
-// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc>  implicit referenced h 'void *'
+// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc>  implicit referenced h 'float *'

diff  --git a/clang/test/CodeGenHLSL/buffer-array-operator.hlsl b/clang/test/CodeGenHLSL/buffer-array-operator.hlsl
new file mode 100644
index 0000000000000..6bcb06106bf1c
--- /dev/null
+++ b/clang/test/CodeGenHLSL/buffer-array-operator.hlsl
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+const RWBuffer<float> In;
+RWBuffer<float> Out;
+
+void fn(int Idx) {
+  Out[Idx] = In[Idx];
+}
+
+// This test is intended to verify reasonable code generation of the subscript
+// operator. In this test case we should be generating both the const and
+// non-const operators so we verify both cases.
+
+// Const comes first (QBA in the mangled name marks a const member function).
+// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QBAAAMI@Z"
+// CHECK: %this1 = load ptr, ptr %this.addr, align 4
+// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0
+// CHECK-NEXT: %0 = load ptr, ptr %h, align 4
+// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4
+// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1
+// CHECK-NEXT: ret ptr %arrayidx
+
+// Non-const comes next (QAA); like the const overload, it returns a pointer to the element rather than the loaded value.
+// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QAAAAMI@Z"
+// CHECK: %this1 = load ptr, ptr %this.addr, align 4
+// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0
+// CHECK-NEXT: %0 = load ptr, ptr %h, align 4
+// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4
+// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1
+// CHECK-NEXT: ret ptr %arrayidx


        


More information about the cfe-commits mailing list