[clang] [HLSL] Allow arrays to be returned by value in HLSL (PR #127896)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Mon Feb 24 10:06:56 PST 2025


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/127896

>From 362b64d31e5f70e4a26ea04c99a58fd5f5ca50ca Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Fri, 14 Feb 2025 12:59:56 -0800
Subject: [PATCH 1/4] Allow arrays to be returned by value in HLSL + test

---
 clang/lib/Sema/SemaExpr.cpp                   |  3 +-
 clang/lib/Sema/SemaType.cpp                   |  6 ++--
 .../BasicFeatures/ArrayReturn.hlsl            | 33 +++++++++++++++++++
 3 files changed, 39 insertions(+), 3 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/ArrayReturn.hlsl

diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index fad15bf95c415..f3a83642ed4e5 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -20681,7 +20681,8 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
   const FunctionType *FnType = CalleeType->castAs<FunctionType>();
 
   // Verify that this is a legal result type of a function.
-  if (DestType->isArrayType() || DestType->isFunctionType()) {
+  if ((DestType->isArrayType() && !S.getLangOpts().HLSL) ||
+      DestType->isFunctionType()) {
     unsigned diagID = diag::err_func_returning_array_function;
     if (Kind == FK_BlockPointer)
       diagID = diag::err_block_returning_array_function;
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index db0177f9750e0..115d2431e020c 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2530,7 +2530,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 }
 
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
-  if (T->isArrayType() || T->isFunctionType()) {
+  if ((T->isArrayType() && !getLangOpts().HLSL) || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
       << T->isFunctionType() << T;
     return true;
@@ -4934,7 +4934,9 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
 
       // C99 6.7.5.3p1: The return type may not be a function or array type.
       // For conversion functions, we'll diagnose this particular error later.
-      if (!D.isInvalidType() && (T->isArrayType() || T->isFunctionType()) &&
+      if (!D.isInvalidType() &&
+          ((T->isArrayType() && !S.getLangOpts().HLSL) ||
+           T->isFunctionType()) &&
           (D.getName().getKind() !=
            UnqualifiedIdKind::IK_ConversionFunctionId)) {
         unsigned diagID = diag::err_func_returning_array_function;
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/ArrayReturn.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/ArrayReturn.hlsl
new file mode 100644
index 0000000000000..832c4ac9b10f5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/ArrayReturn.hlsl
@@ -0,0 +1,33 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -disable-llvm-passes -emit-llvm -finclude-default-header -o - %s | FileCheck %s
+
+typedef int Foo[2];
+
+// CHECK-LABEL: define void {{.*}}boop{{.*}}(ptr dead_on_unwind noalias writable sret([2 x i32]) align 4 %agg.result)
+// CHECK: [[G:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[G]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[AIB:%.*]] = getelementptr inbounds [2 x i32], ptr %agg.result, i32 0, i32 0
+// CHECK-NEXT: br label %arrayinit.body
+// CHECK: arrayinit.body:
+// CHECK-NEXT: [[AII:%.*]] = phi i32 [ 0, %entry ], [ %arrayinit.next, %arrayinit.body ]
+// CHECK-NEXT: [[X:%.*]] = getelementptr inbounds i32, ptr [[AIB]], i32 [[AII]]
+// CHECK-NEXT: [[AI:%.*]] = getelementptr inbounds nuw [2 x i32], ptr [[G]], i32 0, i32 [[AII]]
+// CHECK-NEXT: [[Y:%.*]] = load i32, ptr [[AI]], align 4
+// CHECK-NEXT: store i32 [[Y]], ptr [[X]], align 4
+// CHECK-NEXT: [[AIN:%.*]] = add nuw i32 [[AII]], 1
+// CHECK-NEXT: [[AID:%.*]] = icmp eq i32 [[AIN]], 2
+// CHECK-NEXT: br i1 [[AID]], label %arrayinit.end, label %arrayinit.body
+// CHECK: arrayinit.end:           
+// CHECK-NEXT: ret void
+export Foo boop() {
+  Foo G = {1,2};
+  return G;
+}
+
+// CHECK-LABEL: define void {{.*}}foo{{.*}}(ptr dead_on_unwind noalias writable sret([2 x i32]) align 4 %agg.result)
+// CHECK: store i32 1, ptr %agg.result, align 4
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds i32, ptr %agg.result, i32 1
+// CHECK-NEXT: store i32 2, ptr [[E]], align 4
+// CHECK-NEXT: ret void
+export int foo()[2] {
+  return {1,2};
+}

>From 2598db2039c1a8c93f088b928d669228bf6ff0ab Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Sun, 23 Feb 2025 15:06:56 -0800
Subject: [PATCH 2/4] Address PR comments: add ASTContext::isReturnableArrayType() and use it in array return-type checks

---
 clang/include/clang/AST/ASTContext.h | 4 ++++
 clang/lib/Sema/SemaExpr.cpp          | 5 +++--
 clang/lib/Sema/SemaType.cpp          | 5 +++--
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index d275873651786..d23b08721cfea 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -2803,6 +2803,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
     return getUnqualifiedArrayType(T, Quals);
   }
 
+  // Determine whether an array is a valid return type
+  // Array is a valid return type for HLSL
+  bool isReturnableArrayType() const { return getLangOpts().HLSL; }
+
   /// Determine whether the given types are equivalent after
   /// cvr-qualifiers have been removed.
   bool hasSameUnqualifiedType(QualType T1, QualType T2) const {
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index f3a83642ed4e5..91e71be6dc22c 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -20681,7 +20681,7 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
   const FunctionType *FnType = CalleeType->castAs<FunctionType>();
 
   // Verify that this is a legal result type of a function.
-  if ((DestType->isArrayType() && !S.getLangOpts().HLSL) ||
+  if ((DestType->isArrayType() && !S.Context.isReturnableArrayType()) ||
       DestType->isFunctionType()) {
     unsigned diagID = diag::err_func_returning_array_function;
     if (Kind == FK_BlockPointer)
@@ -20761,7 +20761,8 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
 
 ExprResult RebuildUnknownAnyExpr::VisitObjCMessageExpr(ObjCMessageExpr *E) {
   // Verify that this is a legal result type of a call.
-  if (DestType->isArrayType() || DestType->isFunctionType()) {
+  if ((DestType->isArrayType() && !S.Context.isReturnableArrayType()) ||
+      DestType->isFunctionType()) {
     S.Diag(E->getExprLoc(), diag::err_func_returning_array_function)
       << DestType->isFunctionType() << DestType;
     return ExprError();
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 115d2431e020c..61805a1e708a2 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2530,7 +2530,8 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 }
 
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
-  if ((T->isArrayType() && !getLangOpts().HLSL) || T->isFunctionType()) {
+  if ((T->isArrayType() && !Context.isReturnableArrayType()) ||
+      T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
       << T->isFunctionType() << T;
     return true;
@@ -4935,7 +4936,7 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
       // C99 6.7.5.3p1: The return type may not be a function or array type.
       // For conversion functions, we'll diagnose this particular error later.
       if (!D.isInvalidType() &&
-          ((T->isArrayType() && !S.getLangOpts().HLSL) ||
+          ((T->isArrayType() && !S.Context.isReturnableArrayType()) ||
            T->isFunctionType()) &&
           (D.getName().getKind() !=
            UnqualifiedIdKind::IK_ConversionFunctionId)) {

>From d9a96ac454268ae311a8b08e125855fb1c6f2756 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 24 Feb 2025 09:06:17 -0800
Subject: [PATCH 3/4] Address PR comments (round 2): make isReturnableArrayType take a QualType; revert ObjC message-expr change

---
 clang/include/clang/AST/ASTContext.h | 4 +++-
 clang/lib/Sema/SemaExpr.cpp          | 5 ++---
 clang/lib/Sema/SemaType.cpp          | 4 ++--
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index d23b08721cfea..1247191865ca1 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -2805,7 +2805,9 @@ class ASTContext : public RefCountedBase<ASTContext> {
 
   // Determine whether an array is a valid return type
   // Array is a valid return type for HLSL
-  bool isReturnableArrayType() const { return getLangOpts().HLSL; }
+  bool isReturnableArrayType(QualType T) const {
+    return T->isArrayType() && getLangOpts().HLSL;
+  }
 
   /// Determine whether the given types are equivalent after
   /// cvr-qualifiers have been removed.
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 91e71be6dc22c..ff41828f25ee0 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -20681,7 +20681,7 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
   const FunctionType *FnType = CalleeType->castAs<FunctionType>();
 
   // Verify that this is a legal result type of a function.
-  if ((DestType->isArrayType() && !S.Context.isReturnableArrayType()) ||
+  if ((DestType->isArrayType() && !S.Context.isReturnableArrayType(DestType)) ||
       DestType->isFunctionType()) {
     unsigned diagID = diag::err_func_returning_array_function;
     if (Kind == FK_BlockPointer)
@@ -20761,8 +20761,7 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
 
 ExprResult RebuildUnknownAnyExpr::VisitObjCMessageExpr(ObjCMessageExpr *E) {
   // Verify that this is a legal result type of a call.
-  if ((DestType->isArrayType() && !S.Context.isReturnableArrayType()) ||
-      DestType->isFunctionType()) {
+  if (DestType->isArrayType() || DestType->isFunctionType()) {
     S.Diag(E->getExprLoc(), diag::err_func_returning_array_function)
       << DestType->isFunctionType() << DestType;
     return ExprError();
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 61805a1e708a2..2f496116dbede 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2530,7 +2530,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 }
 
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
-  if ((T->isArrayType() && !Context.isReturnableArrayType()) ||
+  if ((T->isArrayType() && !Context.isReturnableArrayType(T)) ||
       T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
       << T->isFunctionType() << T;
@@ -4936,7 +4936,7 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
       // C99 6.7.5.3p1: The return type may not be a function or array type.
       // For conversion functions, we'll diagnose this particular error later.
       if (!D.isInvalidType() &&
-          ((T->isArrayType() && !S.Context.isReturnableArrayType()) ||
+          ((T->isArrayType() && !S.Context.isReturnableArrayType(T)) ||
            T->isFunctionType()) &&
           (D.getName().getKind() !=
            UnqualifiedIdKind::IK_ConversionFunctionId)) {

>From 01452ceeb7f3fc0e842d743858a5074fe70feba7 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 24 Feb 2025 10:05:18 -0800
Subject: [PATCH 4/4] Replace ASTContext::isReturnableArrayType with LangOptions::allowArrayReturnTypes()

---
 clang/include/clang/AST/ASTContext.h    | 6 ------
 clang/include/clang/Basic/LangOptions.h | 2 ++
 clang/lib/Sema/SemaExpr.cpp             | 2 +-
 clang/lib/Sema/SemaType.cpp             | 4 ++--
 4 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 1247191865ca1..d275873651786 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -2803,12 +2803,6 @@ class ASTContext : public RefCountedBase<ASTContext> {
     return getUnqualifiedArrayType(T, Quals);
   }
 
-  // Determine whether an array is a valid return type
-  // Array is a valid return type for HLSL
-  bool isReturnableArrayType(QualType T) const {
-    return T->isArrayType() && getLangOpts().HLSL;
-  }
-
   /// Determine whether the given types are equivalent after
   /// cvr-qualifiers have been removed.
   bool hasSameUnqualifiedType(QualType T1, QualType T2) const {
diff --git a/clang/include/clang/Basic/LangOptions.h b/clang/include/clang/Basic/LangOptions.h
index 651b3b67b1058..34d75aff43fab 100644
--- a/clang/include/clang/Basic/LangOptions.h
+++ b/clang/include/clang/Basic/LangOptions.h
@@ -816,6 +816,8 @@ class LangOptions : public LangOptionsBase {
            VisibilityForcedKinds::ForceHidden;
   }
 
+  bool allowArrayReturnTypes() const { return HLSL; }
+
   /// Remap path prefix according to -fmacro-prefix-path option.
   void remapPathPrefix(SmallVectorImpl<char> &Path) const;
 
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff41828f25ee0..be35a0e45a636 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -20681,7 +20681,7 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
   const FunctionType *FnType = CalleeType->castAs<FunctionType>();
 
   // Verify that this is a legal result type of a function.
-  if ((DestType->isArrayType() && !S.Context.isReturnableArrayType(DestType)) ||
+  if ((DestType->isArrayType() && !S.getLangOpts().allowArrayReturnTypes()) ||
       DestType->isFunctionType()) {
     unsigned diagID = diag::err_func_returning_array_function;
     if (Kind == FK_BlockPointer)
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 2f496116dbede..60096eebfdb6f 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2530,7 +2530,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 }
 
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
-  if ((T->isArrayType() && !Context.isReturnableArrayType(T)) ||
+  if ((T->isArrayType() && !getLangOpts().allowArrayReturnTypes()) ||
       T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
       << T->isFunctionType() << T;
@@ -4936,7 +4936,7 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
       // C99 6.7.5.3p1: The return type may not be a function or array type.
       // For conversion functions, we'll diagnose this particular error later.
       if (!D.isInvalidType() &&
-          ((T->isArrayType() && !S.Context.isReturnableArrayType(T)) ||
+          ((T->isArrayType() && !S.getLangOpts().allowArrayReturnTypes()) ||
            T->isFunctionType()) &&
           (D.getName().getKind() !=
            UnqualifiedIdKind::IK_ConversionFunctionId)) {



More information about the cfe-commits mailing list