[clang] [HLSL] Implement output parameter (PR #101083)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Jul 29 14:07:06 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-clang-static-analyzer-1
Author: Chris B (llvm-beanz)
<details>
<summary>Changes</summary>
HLSL output parameters are denoted with the `inout` and `out` keywords in the function declaration. When an argument is passed to an output parameter, a temporary value is constructed for it.
For `inout` parameters, the temporary is initialized by casting the argument expression to the parameter type. For `out` parameters, the temporary is not initialized before the call.
In both cases, when the function returns, the temporary value is written back to the argument lvalue expression, through a casting sequence if one is required.
This change introduces a new HLSLOutArgExpr AST node, which represents the output argument behavior. The HLSLOutArgExpr has two defined children: the base expression and the writeback expression. The writeback expression either is, or contains, an OpaqueValueExpr child expression, which code generation uses to represent the temporary value.
---
Patch is 62.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101083.diff
43 Files Affected:
- (modified) clang/include/clang/AST/ASTContext.h (+9)
- (modified) clang/include/clang/AST/Attr.h (+24-14)
- (modified) clang/include/clang/AST/Expr.h (+63)
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+3)
- (modified) clang/include/clang/AST/TextNodeDumper.h (+1)
- (modified) clang/include/clang/Basic/Attr.td (+1-2)
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2)
- (modified) clang/include/clang/Basic/Specifiers.h (+6)
- (modified) clang/include/clang/Basic/StmtNodes.td (+3)
- (modified) clang/include/clang/Sema/SemaHLSL.h (+2)
- (modified) clang/include/clang/Serialization/ASTBitCodes.h (+3)
- (modified) clang/lib/AST/ASTContext.cpp (+14)
- (modified) clang/lib/AST/Expr.cpp (+11)
- (modified) clang/lib/AST/ExprClassification.cpp (+1)
- (modified) clang/lib/AST/ExprConstant.cpp (+1)
- (modified) clang/lib/AST/ItaniumMangle.cpp (+12)
- (modified) clang/lib/AST/StmtPrinter.cpp (+4)
- (modified) clang/lib/AST/StmtProfile.cpp (+4)
- (modified) clang/lib/AST/TextNodeDumper.cpp (+4)
- (modified) clang/lib/AST/TypePrinter.cpp (+15-5)
- (modified) clang/lib/CodeGen/CGCall.cpp (+56)
- (modified) clang/lib/CodeGen/CGCall.h (+16-2)
- (modified) clang/lib/CodeGen/CGExpr.cpp (+26)
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+2)
- (modified) clang/lib/Sema/SemaChecking.cpp (+13)
- (modified) clang/lib/Sema/SemaDecl.cpp (+4)
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
- (modified) clang/lib/Sema/SemaExpr.cpp (+15-4)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+59-1)
- (modified) clang/lib/Sema/SemaOverload.cpp (+4)
- (modified) clang/lib/Sema/SemaSwift.cpp (+3)
- (modified) clang/lib/Sema/SemaType.cpp (+2)
- (modified) clang/lib/Sema/TreeTransform.h (+29)
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+21-1)
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+12)
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+2-1)
- (added) clang/test/AST/HLSL/OutArgExpr.hlsl (+65)
- (added) clang/test/CodeGenHLSL/BasicFeatures/OutputArguments.hlsl (+128)
- (added) clang/test/SemaHLSL/Language/OutputParameters.hlsl (+34)
- (added) clang/test/SemaHLSL/Language/TemplateOutArg.hlsl (+169)
- (modified) clang/test/SemaHLSL/parameter_modifiers.hlsl (+3-5)
- (modified) clang/test/SemaHLSL/parameter_modifiers_ast.hlsl (+10-10)
- (modified) clang/tools/libclang/CXCursor.cpp (+1)
``````````diff
diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 6d1c8ca8a2f96..534ffd994cd67 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1377,6 +1377,15 @@ class ASTContext : public RefCountedBase<ASTContext> {
/// in the return type and parameter types.
bool hasSameFunctionTypeIgnoringPtrSizes(QualType T, QualType U);
+
+ /// Get or construct a function type that is equivalent to the input type
+ /// except that the parameter ABI annotations are stripped.
+ QualType getFunctionTypeWithoutParamABIs(QualType T);
+
+ /// Determine if two function types are the same, ignoring parameter ABI
+ /// annotations.
+ bool hasSameFunctionTypeIgnoringParamABI(QualType T, QualType U);
+
/// Return the uniqued reference to the type for a complex
/// number with the specified element type.
QualType getComplexType(QualType T) const;
diff --git a/clang/include/clang/AST/Attr.h b/clang/include/clang/AST/Attr.h
index 8e9b7ad8b4682..00e3c9d9ab347 100644
--- a/clang/include/clang/AST/Attr.h
+++ b/clang/include/clang/AST/Attr.h
@@ -224,20 +224,7 @@ class ParameterABIAttr : public InheritableParamAttr {
InheritEvenIfAlreadyPresent) {}
public:
- ParameterABI getABI() const {
- switch (getKind()) {
- case attr::SwiftContext:
- return ParameterABI::SwiftContext;
- case attr::SwiftAsyncContext:
- return ParameterABI::SwiftAsyncContext;
- case attr::SwiftErrorResult:
- return ParameterABI::SwiftErrorResult;
- case attr::SwiftIndirectResult:
- return ParameterABI::SwiftIndirectResult;
- default:
- llvm_unreachable("bad parameter ABI attribute kind");
- }
- }
+ ParameterABI getABI() const;
static bool classof(const Attr *A) {
return A->getKind() >= attr::FirstParameterABIAttr &&
@@ -379,6 +366,29 @@ inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &DB,
DB.AddTaggedVal(reinterpret_cast<uint64_t>(At), DiagnosticsEngine::ak_attr);
return DB;
}
+
+inline ParameterABI ParameterABIAttr::getABI() const {
+ switch (getKind()) {
+ case attr::SwiftContext:
+ return ParameterABI::SwiftContext;
+ case attr::SwiftAsyncContext:
+ return ParameterABI::SwiftAsyncContext;
+ case attr::SwiftErrorResult:
+ return ParameterABI::SwiftErrorResult;
+ case attr::SwiftIndirectResult:
+ return ParameterABI::SwiftIndirectResult;
+ case attr::HLSLParamModifier: {
+ const auto *A = cast<HLSLParamModifierAttr>(this);
+ if (A->isOut())
+ return ParameterABI::HLSLOut;
+ if (A->isInOut())
+ return ParameterABI::HLSLInOut;
+ return ParameterABI::Ordinary;
+ }
+ default:
+ llvm_unreachable("bad parameter ABI attribute kind");
+ }
+}
} // end namespace clang
#endif
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 5b813bfc2faf9..039719fc1c20d 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -7061,6 +7061,69 @@ class ArraySectionExpr : public Expr {
void setRBracketLoc(SourceLocation L) { RBracketLoc = L; }
};
+/// This class represents temporary values used to represent inout and out
+/// arguments in HLSL. From the callee perspective these parameters are more or
+/// less __restrict__ T&. They are guaranteed to not alias any memory. inout
+/// parameters are initialized by the caller, and out parameters are references
+/// to uninitialized memory.
+///
+/// In the caller, the argument expression creates a temporary in local memory
+/// and the address of the temporary is passed into the callee. There may be
+/// implicit conversion sequences to initialize the temporary, and on expiration
+/// of the temporary an inverse conversion sequence is applied as a write-back
+/// conversion to the source l-value.
+class HLSLOutArgExpr : public Expr {
+ friend class ASTStmtReader;
+
+ Expr *Base;
+ Expr *Writeback;
+ OpaqueValueExpr *OpaqueVal;
+ bool IsInOut;
+
+ HLSLOutArgExpr(QualType Ty, Expr *B, Expr *WB, OpaqueValueExpr *OpV,
+ bool IsInOut)
+ : Expr(HLSLOutArgExprClass, Ty, VK_LValue, OK_Ordinary), Base(B),
+ Writeback(WB), OpaqueVal(OpV), IsInOut(IsInOut) {
+ assert(!Ty->isDependentType() && "HLSLOutArgExpr given a dependent type!");
+ }
+
+ explicit HLSLOutArgExpr(EmptyShell Shell)
+ : Expr(HLSLOutArgExprClass, Shell) {}
+
+public:
+ static HLSLOutArgExpr *Create(const ASTContext &C, QualType Ty, Expr *Base,
+ bool IsInOut, Expr *WB,
+ OpaqueValueExpr *OpV);
+ static HLSLOutArgExpr *CreateEmpty(const ASTContext &Ctx);
+
+ const Expr *getBase() const { return Base; }
+ Expr *getBase() { return Base; }
+
+ const Expr *getWriteback() const { return Writeback; }
+ Expr *getWriteback() { return Writeback; }
+
+ const OpaqueValueExpr *getOpaqueValue() const { return OpaqueVal; }
+ OpaqueValueExpr *getOpaqueValue() { return OpaqueVal; }
+
+ bool isInOut() const { return IsInOut; }
+
+ SourceLocation getBeginLoc() const LLVM_READONLY {
+ return Base->getBeginLoc();
+ }
+
+ SourceLocation getEndLoc() const LLVM_READONLY { return Base->getEndLoc(); }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == HLSLOutArgExprClass;
+ }
+
+ // Iterators
+ child_range children() {
+ return child_range((Stmt**)&Base, ((Stmt**)&Writeback) + 1);
+ }
+};
+
+
/// Frontend produces RecoveryExprs on semantic errors that prevent creating
/// other well-formed expressions. E.g. when type-checking of a binary operator
/// fails, we cannot produce a BinaryOperator expression. Instead, we can choose
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e3c0cb46799f7..27c29099c57cf 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -4014,6 +4014,9 @@ DEF_TRAVERSE_STMT(OpenACCComputeConstruct,
DEF_TRAVERSE_STMT(OpenACCLoopConstruct,
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
+// Traverse HLSL: Out argument expression
+DEF_TRAVERSE_STMT(HLSLOutArgExpr, {})
+
// FIXME: look at the following tricky-seeming exprs to see if we
// need to recurse on anything. These are ones that have methods
// returning decls or qualtypes or nestednamespecifier -- though I'm
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index 39dd1f515c9eb..261853343a011 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -407,6 +407,7 @@ class TextNodeDumper
void
VisitLifetimeExtendedTemporaryDecl(const LifetimeExtendedTemporaryDecl *D);
void VisitHLSLBufferDecl(const HLSLBufferDecl *D);
+ void VisitHLSLOutArgExpr(const HLSLOutArgExpr *E);
void VisitOpenACCConstructStmt(const OpenACCConstructStmt *S);
void VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S);
void VisitEmbedExpr(const EmbedExpr *S);
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 46d0a66d59c37..6186161e6b182 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4613,14 +4613,13 @@ def HLSLGroupSharedAddressSpace : TypeAttr {
let Documentation = [HLSLGroupSharedAddressSpaceDocs];
}
-def HLSLParamModifier : TypeAttr {
+def HLSLParamModifier : ParameterABIAttr {
let Spellings = [CustomKeyword<"in">, CustomKeyword<"inout">, CustomKeyword<"out">];
let Accessors = [Accessor<"isIn", [CustomKeyword<"in">]>,
Accessor<"isInOut", [CustomKeyword<"inout">]>,
Accessor<"isOut", [CustomKeyword<"out">]>,
Accessor<"isAnyOut", [CustomKeyword<"out">, CustomKeyword<"inout">]>,
Accessor<"isAnyIn", [CustomKeyword<"in">, CustomKeyword<"inout">]>];
- let Subjects = SubjectList<[ParmVar]>;
let Documentation = [HLSLParamQualifierDocs];
let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
}
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 581434d33c5c9..c499ee8ac5906 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12357,6 +12357,8 @@ def warn_hlsl_availability : Warning<
def warn_hlsl_availability_unavailable :
Warning<err_unavailable.Summary>,
InGroup<HLSLAvailability>, DefaultError;
+def error_hlsl_inout_scalar_extension : Error<"illegal scalar extension cast on argument %0 to %select{|in}1out paramemter">;
+def error_hlsl_inout_lvalue : Error<"cannot bind non-lvalue argument %0 to %select{|in}1out paramemter">;
def err_hlsl_export_not_on_function : Error<
"export declaration can only be used on functions">;
diff --git a/clang/include/clang/Basic/Specifiers.h b/clang/include/clang/Basic/Specifiers.h
index fb11e8212f8b6..0ffd9e06cf3e5 100644
--- a/clang/include/clang/Basic/Specifiers.h
+++ b/clang/include/clang/Basic/Specifiers.h
@@ -382,6 +382,12 @@ namespace clang {
/// Swift asynchronous context-pointer ABI treatment. There can be at
/// most one parameter on a given function that uses this treatment.
SwiftAsyncContext,
+
+ // This parameter is a copy-out HLSL parameter.
+ HLSLOut,
+
+ // This parameter is a copy-in/copy-out HLSL parameter.
+ HLSLInOut,
};
/// Assigned inheritance model for a class in the MS C++ ABI. Must match order
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 9bf23fae50a9e..a80601b1e4a94 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -306,3 +306,6 @@ def OpenACCAssociatedStmtConstruct
: StmtNode<OpenACCConstructStmt, /*abstract=*/1>;
def OpenACCComputeConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
def OpenACCLoopConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
+
+// HLSL Constructs.
+def HLSLOutArgExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 2ddbee67c414b..64b565787f325 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -61,6 +61,8 @@ class SemaHLSL : public SemaBase {
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL);
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+
+ ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
};
} // namespace clang
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 5dd0ba33f8a9c..c19d750d30d56 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1988,6 +1988,9 @@ enum StmtCode {
// OpenACC Constructs
STMT_OPENACC_COMPUTE_CONSTRUCT,
STMT_OPENACC_LOOP_CONSTRUCT,
+
+ // HLSL Constructs
+ EXPR_HLSL_OUT_ARG,
};
/// The kinds of designators that can occur in a
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index a465cdfcf3c89..750928fc00928 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3590,6 +3590,20 @@ bool ASTContext::hasSameFunctionTypeIgnoringPtrSizes(QualType T, QualType U) {
getFunctionTypeWithoutPtrSizes(U));
}
+QualType ASTContext::getFunctionTypeWithoutParamABIs(QualType T) {
+ if (const auto *Proto = T->getAs<FunctionProtoType>()) {
+ FunctionProtoType::ExtProtoInfo EPI = Proto->getExtProtoInfo();
+ EPI.ExtParameterInfos = nullptr;
+ return getFunctionType(Proto->getReturnType(), Proto->param_types(), EPI);
+ }
+ return T;
+}
+
+bool ASTContext::hasSameFunctionTypeIgnoringParamABI(QualType T, QualType U) {
+ return hasSameType(T, U) || hasSameType(getFunctionTypeWithoutParamABIs(T),
+ getFunctionTypeWithoutParamABIs(U));
+}
+
void ASTContext::adjustExceptionSpec(
FunctionDecl *FD, const FunctionProtoType::ExceptionSpecInfo &ESI,
bool AsWritten) {
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 9d5b8167d0ee6..be12e6e93cc45 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3631,6 +3631,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
case RequiresExprClass:
case SYCLUniqueStableNameExprClass:
case PackIndexingExprClass:
+ case HLSLOutArgExprClass:
// These never have a side-effect.
return false;
@@ -5318,3 +5319,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
alignof(OMPIteratorExpr));
return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
}
+
+HLSLOutArgExpr *HLSLOutArgExpr::Create(const ASTContext &C, QualType Ty,
+ Expr *Base, bool IsInOut, Expr *WB,
+ OpaqueValueExpr *OpV) {
+ return new (C) HLSLOutArgExpr(Ty, Base, WB, OpV, IsInOut);
+}
+
+HLSLOutArgExpr *HLSLOutArgExpr::CreateEmpty(const ASTContext &C) {
+ return new (C) HLSLOutArgExpr(EmptyShell());
+}
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index 6482cb6d39acc..ebbfaa187263f 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
case Expr::ArraySectionExprClass:
case Expr::OMPArrayShapingExprClass:
case Expr::OMPIteratorExprClass:
+ case Expr::HLSLOutArgExprClass:
return Cl::CL_LValue;
// C99 6.5.2.5p5 says that compound literals are lvalues.
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 558e20ed3e423..c692d47ffd1af 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -16469,6 +16469,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
case Expr::CoyieldExprClass:
case Expr::SYCLUniqueStableNameExprClass:
case Expr::CXXParenListInitExprClass:
+ case Expr::HLSLOutArgExprClass:
return ICEDiag(IK_NotICE, E->getBeginLoc());
case Expr::InitListExprClass: {
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index d46d621d4c7d4..1a1c316b90c4e 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3507,6 +3507,12 @@ CXXNameMangler::mangleExtParameterInfo(FunctionProtoType::ExtParameterInfo PI) {
case ParameterABI::Ordinary:
break;
+ // HLSL parameter mangling.
+ case ParameterABI::HLSLOut:
+ case ParameterABI::HLSLInOut:
+ mangleVendorQualifier(getParameterABISpelling(PI.getABI()));
+ break;
+
// All of these start with "swift", so they come before "ns_consumed".
case ParameterABI::SwiftContext:
case ParameterABI::SwiftAsyncContext:
@@ -5703,6 +5709,12 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
Out << "E";
break;
}
+ case Expr::HLSLOutArgExprClass: {
+ const auto *OAE = cast<clang::HLSLOutArgExpr>(E);
+ Out << (OAE->isInOut() ? "_inout_" : "_out_");
+ mangleType(E->getType());
+ break;
+ }
}
if (AsTemplateArg && !IsPrimaryExpr)
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 69e0b763e8ddc..0a2ac8c16671a 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -2799,6 +2799,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
OS << ")";
}
+void StmtPrinter::VisitHLSLOutArgExpr(HLSLOutArgExpr *Node) {
+ PrintExpr(Node->getBase());
+}
+
//===----------------------------------------------------------------------===//
// Stmt method implementations
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 89d2a422509d8..37812812ee2b3 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2631,6 +2631,10 @@ void StmtProfiler::VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S) {
P.VisitOpenACCClauseList(S->clauses());
}
+void StmtProfiler::VisitHLSLOutArgExpr(const HLSLOutArgExpr *S) {
+ VisitStmt(S);
+}
+
void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
bool Canonical, bool ProfileLambdaExpr) const {
StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 5ba9523504258..ff88f4aec98a5 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -2874,6 +2874,10 @@ void TextNodeDumper::VisitHLSLBufferDecl(const HLSLBufferDecl *D) {
dumpName(D);
}
+void TextNodeDumper::VisitHLSLOutArgExpr(const HLSLOutArgExpr *E) {
+ OS << (E->isInOut() ? " inout" : " out");
+}
+
void TextNodeDumper::VisitOpenACCConstructStmt(const OpenACCConstructStmt *S) {
OS << " " << S->getDirectiveKind();
}
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ffec3ef9d2269..b88e9b8f7f471 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -933,6 +933,10 @@ StringRef clang::getParameterABISpelling(ParameterABI ABI) {
return "swift_error_result";
case ParameterABI::SwiftIndirectResult:
return "swift_indirect_result";
+ case ParameterABI::HLSLOut:
+ return "out";
+ case ParameterABI::HLSLInOut:
+ return "inout";
}
llvm_unreachable("bad parameter ABI kind");
}
@@ -955,7 +959,17 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
if (EPI.isNoEscape())
OS << "__attribute__((noescape)) ";
auto ABI = EPI.getABI();
- if (ABI != ParameterABI::Ordinary)
+ if (ABI == ParameterABI::HLSLInOut || ABI == ParameterABI::HLSLOut) {
+ OS << getParameterABISpelling(ABI) << " ";
+ if (Policy.UseHLSLTypes) {
+ // This is a bit of a hack because we _do_ use reference types in the
+ // AST for representing inout and out parameters so that code
+ // generation is sane, but when re-printing these for HLSL we need to
+ // skip the reference.
+ print(T->getParamType(i).getNonReferenceType(), OS, StringRef());
+ continue;
+ }
+ } else if (ABI != ParameterABI::Ordinary)
OS << "__attribute__((" << getParameterABISpelling(ABI) << ")) ";
print(T->getParamType(i), OS, StringRef());
@@ -2023,10 +2037,6 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
case attr::ArmMveStrictPolymorphism:
OS << "__clang_arm_mve_strict_polymorphism";
break;
-
- // Nothing to print for this attribute.
- case attr::HLSLParamModifier:
- break;
}
OS << "))";
}
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 2f3dd5d01fa6c..4835c7ab3e894 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -2830,6 +2830,9 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
}
switch (FI.getExtParameterInfo(ArgNo).getABI()) {
+ case ParameterABI::HLSLOut:
+ case ParameterABI::HLSLInOut:
+ // FIXME: Do this...
case ParameterABI::Ordinary:
break;
@@ -4148,6 +4151,30 @@ static void emitWriteback(CodeGenFunction &CGF,
assert(!isProvablyNull(srcAddr.getBasePointer()) &&
"shouldn't have writeback for provably null argument");
+ if (CGF.getLangOpts().HLSL) {
+ if (writeback.CastExpr) {
+ RValue TmpVal = CGF.EmitAnyExprToTemp(writeback.CastExpr);
+ if (TmpVal.isScalar())
+ CGF.EmitStoreThroughLValue(TmpVal, srcLV);
+ else
+ CGF.EmitAggregateStore(srcLV.getPointer(CGF),
+ TmpVal.getAggregateAddress(), false);
+ } else {
+ if (srcLV.isSimple())
+ CGF.EmitAggregateStore(srcLV.getPointer(CGF), writeback.Temporary,
+ false);
+ else {
+ llvm::Value *value = CGF.Builder.CreateLoad(writeback.Temporary);
+ RValue TmpVal = RValue::get(value);
+ CGF.EmitStoreThroughLValue(TmpVal, srcLV);
+ }
+ }
+ if (writeback.LifetimeSz)
+ CGF.EmitLifetimeEnd(writeback.LifetimeSz,
+ writeback.Temporary.getBasePointer());
+ return;
+ }
+
llvm::BasicBlock *contBB = nullp...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/101083
More information about the cfe-commits
mailing list