[clang] [llvm] [RISCV][VLS] Support RISCV VLS calling convention (PR #100346)

via cfe-commits cfe-commits at lists.llvm.org
Wed Jul 24 04:12:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-debuginfo

@llvm/pr-subscribers-clang-modules

Author: Brandon Wu (4vtomat)

<details>
<summary>Changes</summary>

This patch adds a function attribute `riscv_vls_cc` for the RISC-V VLS calling
convention. The attribute takes zero or one argument: the `ABI_VLEN`, i.e. the
`VLEN` assumed when passing fixed-length vector arguments. Each such argument is
wrapped as a scalable vector (VLA) sized according to `ABI_VLEN` and is then
handled by the corresponding scalable-vector mechanism. The valid range of
`ABI_VLEN` is [32, 65536]; if it is not specified, the default value is 128.

An option `-mriscv-abi-vlen=N` is also added to specify the `ABI_VLEN`
globally; it applies to every function being compiled. If both the function
attribute and the option are specified, the function attribute takes priority,
i.e. it overrides the `ABI_VLEN` specified by the option.

Here is an example of VLS argument passing:
Non-VLS call:
```
  void original_call(__attribute__((vector_size(16))) int arg) {}
=>
  define void @original_call(i128 noundef %arg) {
  entry:
    ...
    ret void
  }
```
VLS call:
```
  void __attribute__((riscv_vls_cc(256))) vls_call(__attribute__((vector_size(16))) int arg) {}
=>
  define riscv_vls_cc void @vls_call(<vscale x 1 x i32> %arg) {
  entry:
    ...
    ret void
  }
```

The first (non-VLS) call passes the 16-byte generic vector argument as a
flattened integer (`i128`).
In contrast, the VLS call uses `ABI_VLEN=256`, which wraps the vector as
`<vscale x 1 x i32>`. The number of scalable-vector elements is calculated as
`ORIG_ELTS * RVV_BITS_PER_BLOCK / ABI_VLEN` = 4 * 64 / 256 = 1.
Note: ORIG_ELTS = vector size / element size = 128 / 32 = 4 (sizes in bits).


---

Patch is 38.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100346.diff


33 Files Affected:

- (modified) clang/include/clang-c/Index.h (+1) 
- (modified) clang/include/clang/AST/Type.h (+21-5) 
- (modified) clang/include/clang/AST/TypeProperties.td (+5-2) 
- (modified) clang/include/clang/Basic/Attr.td (+8) 
- (modified) clang/include/clang/Basic/AttrDocs.td (+11) 
- (modified) clang/include/clang/Basic/Specifiers.h (+1) 
- (modified) clang/include/clang/CodeGen/CGFunctionInfo.h (+8-1) 
- (modified) clang/include/clang/Driver/Options.td (+2) 
- (modified) clang/lib/AST/ASTContext.cpp (+2) 
- (modified) clang/lib/AST/ItaniumMangle.cpp (+1) 
- (modified) clang/lib/AST/Type.cpp (+2) 
- (modified) clang/lib/AST/TypePrinter.cpp (+6) 
- (modified) clang/lib/Basic/Targets/RISCV.cpp (+1) 
- (modified) clang/lib/CodeGen/CGCall.cpp (+5) 
- (modified) clang/lib/CodeGen/CGDebugInfo.cpp (+2) 
- (modified) clang/lib/CodeGen/Targets/RISCV.cpp (+47-26) 
- (modified) clang/lib/Driver/ToolChains/Arch/RISCV.cpp (+4) 
- (modified) clang/lib/Sema/SemaDeclAttr.cpp (+26-4) 
- (modified) clang/lib/Sema/SemaType.cpp (+16-1) 
- (modified) clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c (+24) 
- (modified) clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp (+14) 
- (modified) clang/test/CodeGen/RISCV/riscv-vector-callingconv.c (+16) 
- (modified) clang/test/CodeGen/RISCV/riscv-vector-callingconv.cpp (+17) 
- (modified) clang/tools/libclang/CXType.cpp (+1) 
- (modified) llvm/include/llvm/AsmParser/LLToken.h (+1) 
- (modified) llvm/include/llvm/BinaryFormat/Dwarf.def (+1) 
- (modified) llvm/include/llvm/IR/CallingConv.h (+3) 
- (modified) llvm/lib/AsmParser/LLLexer.cpp (+1) 
- (modified) llvm/lib/AsmParser/LLParser.cpp (+4) 
- (modified) llvm/lib/IR/AsmWriter.cpp (+3) 
- (modified) llvm/lib/Target/RISCV/RISCVFeatures.td (+9) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+1) 
- (modified) llvm/lib/Target/RISCV/RISCVSubtarget.h (+1) 


``````````diff
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 115f5ab090f96..159f21846fc3b 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -3005,6 +3005,7 @@ enum CXCallingConv {
   CXCallingConv_M68kRTD = 19,
   CXCallingConv_PreserveNone = 20,
   CXCallingConv_RISCVVectorCall = 21,
+  CXCallingConv_RISCVVLSCall = 22,
 
   CXCallingConv_Invalid = 100,
   CXCallingConv_Unexposed = 200
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 25defea58c2dc..d1c6e629e296c 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -1942,7 +1942,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     /// Extra information which affects how the function is called, like
     /// regparm and the calling convention.
     LLVM_PREFERRED_TYPE(CallingConv)
-    unsigned ExtInfo : 13;
+    unsigned ExtInfo : 17;
 
     /// The ref-qualifier associated with a \c FunctionProtoType.
     ///
@@ -4395,6 +4395,8 @@ class FunctionType : public Type {
 
     // |  CC  |noreturn|produces|nocallersavedregs|regparm|nocfcheck|cmsenscall|
     // |0 .. 4|   5    |    6   |       7         |8 .. 10|    11   |    12    |
+    // |RISCV-ABI-VLEN|
+    // |13    ..    17|
     //
     // regparm is either 0 (no regparm attribute) or the regparm value+1.
     enum { CallConvMask = 0x1F };
@@ -4407,23 +4409,25 @@ class FunctionType : public Type {
     };
     enum { NoCfCheckMask = 0x800 };
     enum { CmseNSCallMask = 0x1000 };
-    uint16_t Bits = CC_C;
+    enum { Log2RISCVABIVLenMask = 0x1E000, Log2RISCVABIVLenOffset = 13 };
+    uint32_t Bits = CC_C;
 
-    ExtInfo(unsigned Bits) : Bits(static_cast<uint16_t>(Bits)) {}
+    ExtInfo(unsigned Bits) : Bits(static_cast<uint32_t>(Bits)) {}
 
   public:
     // Constructor with no defaults. Use this when you know that you
     // have all the elements (when reading an AST file for example).
     ExtInfo(bool noReturn, bool hasRegParm, unsigned regParm, CallingConv cc,
             bool producesResult, bool noCallerSavedRegs, bool NoCfCheck,
-            bool cmseNSCall) {
+            bool cmseNSCall, unsigned Log2RISCVABIVLen) {
       assert((!hasRegParm || regParm < 7) && "Invalid regparm value");
       Bits = ((unsigned)cc) | (noReturn ? NoReturnMask : 0) |
              (producesResult ? ProducesResultMask : 0) |
              (noCallerSavedRegs ? NoCallerSavedRegsMask : 0) |
              (hasRegParm ? ((regParm + 1) << RegParmOffset) : 0) |
              (NoCfCheck ? NoCfCheckMask : 0) |
-             (cmseNSCall ? CmseNSCallMask : 0);
+             (cmseNSCall ? CmseNSCallMask : 0) |
+             (Log2RISCVABIVLen << Log2RISCVABIVLenOffset);
     }
 
     // Constructor with all defaults. Use when for example creating a
@@ -4450,6 +4454,10 @@ class FunctionType : public Type {
 
     CallingConv getCC() const { return CallingConv(Bits & CallConvMask); }
 
+    unsigned getLog2RISCVABIVLen() const {
+      return (Bits & Log2RISCVABIVLenMask) >> Log2RISCVABIVLenOffset;
+    }
+
     bool operator==(ExtInfo Other) const {
       return Bits == Other.Bits;
     }
@@ -4505,6 +4513,11 @@ class FunctionType : public Type {
       return ExtInfo((Bits & ~CallConvMask) | (unsigned) cc);
     }
 
+    ExtInfo withLog2RISCVABIVLen(unsigned Log2RISCVABIVLen) const {
+      return ExtInfo((Bits & ~Log2RISCVABIVLenMask) |
+                     (Log2RISCVABIVLen << Log2RISCVABIVLenOffset));
+    }
+
     void Profile(llvm::FoldingSetNodeID &ID) const {
       ID.AddInteger(Bits);
     }
@@ -4609,6 +4622,9 @@ class FunctionType : public Type {
 
   bool getCmseNSCallAttr() const { return getExtInfo().getCmseNSCall(); }
   CallingConv getCallConv() const { return getExtInfo().getCC(); }
+  unsigned getLog2RISCVABIVLen() const {
+    return getExtInfo().getLog2RISCVABIVLen();
+  }
   ExtInfo getExtInfo() const { return ExtInfo(FunctionTypeBits.ExtInfo); }
 
   static_assert((~Qualifiers::FastMask & Qualifiers::CVRMask) == 0,
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index 7d4353c2773a3..66bff0f879b56 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -313,6 +313,9 @@ let Class = FunctionType in {
   def : Property<"cmseNSCall", Bool> {
     let Read = [{ node->getExtInfo().getCmseNSCall() }];
   }
+  def : Property<"Log2RISCVABIVLen", UInt32> {
+    let Read = [{ node->getExtInfo().getLog2RISCVABIVLen() }];
+  }
 }
 
 let Class = FunctionNoProtoType in {
@@ -320,7 +323,7 @@ let Class = FunctionNoProtoType in {
     auto extInfo = FunctionType::ExtInfo(noReturn, hasRegParm, regParm,
                                          callingConvention, producesResult,
                                          noCallerSavedRegs, noCfCheck,
-                                         cmseNSCall);
+                                         cmseNSCall, Log2RISCVABIVLen);
     return ctx.getFunctionNoProtoType(returnType, extInfo);
   }]>;
 }
@@ -363,7 +366,7 @@ let Class = FunctionProtoType in {
     auto extInfo = FunctionType::ExtInfo(noReturn, hasRegParm, regParm,
                                          callingConvention, producesResult,
                                          noCallerSavedRegs, noCfCheck,
-                                         cmseNSCall);
+                                         cmseNSCall, Log2RISCVABIVLen);
     FunctionProtoType::ExtProtoInfo epi;
     epi.ExtInfo = extInfo;
     epi.Variadic = variadic;
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 4825979a974d2..ec2c1bedaef50 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3139,6 +3139,14 @@ def RISCVVectorCC: DeclOrTypeAttr, TargetSpecificAttr<TargetRISCV> {
  let Documentation = [RISCVVectorCCDocs];
 }
 
+def RISCVVLSCC: DeclOrTypeAttr, TargetSpecificAttr<TargetRISCV> {
+ let Spellings = [CXX11<"riscv", "vls_cc">,
+                  C23<"riscv", "vls_cc">,
+                  Clang<"riscv_vls_cc">];
+ let Args = [UnsignedArgument<"VectorWidth", /*opt*/1>];
+ let Documentation = [RISCVVLSCCDocs];
+}
+
 def Target : InheritableAttr {
   let Spellings = [GCC<"target">];
   let Args = [StringArgument<"featuresStr">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 99738812c8157..1eba3b2945f7b 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -5554,6 +5554,17 @@ them if they use them.
  }];
 }
 
+def RISCVVLSCCDocs : Documentation {
+ let Category = DocCatCallingConvs;
+ let Heading = "riscv::vls_cc, riscv_vls_cc, clang::riscv_vls_cc";
+ let Content = [{
+The ``riscv_vls_cc`` attribute can be applied to a function. Functions
+declared with this attribute will utilize the standard fixed-length vector
+calling convention variant instead of the default calling convention defined by
+the ABI. This variant aims to pass fixed-length vectors via vector registers,
+if possible, rather than through general-purpose registers.}];
+}
+
 def PreferredNameDocs : Documentation {
   let Category = DocCatDecl;
   let Content = [{
diff --git a/clang/include/clang/Basic/Specifiers.h b/clang/include/clang/Basic/Specifiers.h
index fb11e8212f8b6..81b0b856c33d0 100644
--- a/clang/include/clang/Basic/Specifiers.h
+++ b/clang/include/clang/Basic/Specifiers.h
@@ -297,6 +297,7 @@ namespace clang {
     CC_M68kRTD,           // __attribute__((m68k_rtd))
     CC_PreserveNone,      // __attribute__((preserve_none))
     CC_RISCVVectorCall,   // __attribute__((riscv_vector_cc))
+    CC_RISCVVLSCall,      // __attribute__((riscv_vls_cc))
   };
 
   /// Checks whether the given calling convention supports variadic
diff --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
index 811f33407368c..aae13d77d9050 100644
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -608,6 +608,9 @@ class CGFunctionInfo final
   /// Log 2 of the maximum vector width.
   unsigned MaxVectorWidth : 4;
 
+  /// Log2 of ABI_VLEN used in RISCV VLS calling convention.
+  unsigned Log2RISCVABIVLen : 4;
+
   RequiredArgs Required;
 
   /// The struct representing all arguments passed in memory.  Only used when
@@ -718,11 +721,13 @@ class CGFunctionInfo final
   bool getHasRegParm() const { return HasRegParm; }
   unsigned getRegParm() const { return RegParm; }
 
+  unsigned getLog2RISCVABIVLen() const { return Log2RISCVABIVLen; }
+
   FunctionType::ExtInfo getExtInfo() const {
     return FunctionType::ExtInfo(isNoReturn(), getHasRegParm(), getRegParm(),
                                  getASTCallingConvention(), isReturnsRetained(),
                                  isNoCallerSavedRegs(), isNoCfCheck(),
-                                 isCmseNSCall());
+                                 isCmseNSCall(), getLog2RISCVABIVLen());
   }
 
   CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
@@ -776,6 +781,7 @@ class CGFunctionInfo final
     ID.AddInteger(RegParm);
     ID.AddBoolean(NoCfCheck);
     ID.AddBoolean(CmseNSCall);
+    ID.AddInteger(Log2RISCVABIVLen);
     ID.AddInteger(Required.getOpaqueData());
     ID.AddBoolean(HasExtParameterInfos);
     if (HasExtParameterInfos) {
@@ -803,6 +809,7 @@ class CGFunctionInfo final
     ID.AddInteger(info.getRegParm());
     ID.AddBoolean(info.getNoCfCheck());
     ID.AddBoolean(info.getCmseNSCall());
+    ID.AddInteger(info.getLog2RISCVABIVLen());
     ID.AddInteger(required.getOpaqueData());
     ID.AddBoolean(!paramInfos.empty());
     if (!paramInfos.empty()) {
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index fa36405ec1bdd..aafbf9eec786f 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4865,6 +4865,8 @@ def mrvv_vector_bits_EQ : Joined<["-"], "mrvv-vector-bits=">, Group<m_Group>,
       !eq(GlobalDocumentation.Program, "Flang") : "",
       true: " The value will be reflected in __riscv_v_fixed_vlen preprocessor define"),
     " (RISC-V only)")>;
+def mriscv_abi_vlen_EQ : Joined<["-"], "mriscv-abi-vlen=">, Group<m_Group>,
+                         HelpText<"Specify the VLEN for VLS calling convention.">;
 
 def munaligned_access : Flag<["-"], "munaligned-access">, Group<m_Group>,
   HelpText<"Allow memory accesses to be unaligned (AArch32/MIPSr6 only)">;
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 7af9ea7105bb0..8369b590809d6 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -10825,6 +10825,8 @@ QualType ASTContext::mergeFunctionTypes(QualType lhs, QualType rhs,
     return {};
   if (lbaseInfo.getNoCfCheck() != rbaseInfo.getNoCfCheck())
     return {};
+  if (lbaseInfo.getLog2RISCVABIVLen() != rbaseInfo.getLog2RISCVABIVLen())
+    return {};
 
   // When merging declarations, it's common for supplemental information like
   // attributes to only be present in one of the declarations, and we generally
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index d46d621d4c7d4..ba8f2a4c6776b 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3452,6 +3452,7 @@ StringRef CXXNameMangler::getCallingConvQualifierName(CallingConv CC) {
   case CC_M68kRTD:
   case CC_PreserveNone:
   case CC_RISCVVectorCall:
+  case CC_RISCVVLSCall:
     // FIXME: we should be mangling all of the above.
     return "";
 
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index fdaab8e434593..7e2ffb09e340a 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -3510,6 +3510,7 @@ StringRef FunctionType::getNameForCallConv(CallingConv CC) {
     // clang-format off
   case CC_RISCVVectorCall: return "riscv_vector_cc";
     // clang-format on
+  case CC_RISCVVLSCall: return "riscv_vls_cc";
   }
 
   llvm_unreachable("Invalid calling convention.");
@@ -4162,6 +4163,7 @@ bool AttributedType::isCallingConv() const {
   case attr::M68kRTD:
   case attr::PreserveNone:
   case attr::RISCVVectorCC:
+  case attr::RISCVVLSCC:
     return true;
   }
   llvm_unreachable("invalid attr kind");
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ffec3ef9d2269..1a66843f7600d 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1114,6 +1114,9 @@ void TypePrinter::printFunctionAfter(const FunctionType::ExtInfo &Info,
     case CC_RISCVVectorCall:
       OS << "__attribute__((riscv_vector_cc))";
       break;
+    case CC_RISCVVLSCall:
+      OS << "__attribute__((riscv_vls_cc))";
+      break;
     }
   }
 
@@ -2014,6 +2017,9 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::RISCVVectorCC:
     OS << "riscv_vector_cc";
     break;
+  case attr::RISCVVLSCC:
+    OS << "riscv_vls_cc";
+    break;
   case attr::NoDeref:
     OS << "noderef";
     break;
diff --git a/clang/lib/Basic/Targets/RISCV.cpp b/clang/lib/Basic/Targets/RISCV.cpp
index 41d836330b38c..7b649f05f0aa9 100644
--- a/clang/lib/Basic/Targets/RISCV.cpp
+++ b/clang/lib/Basic/Targets/RISCV.cpp
@@ -476,6 +476,7 @@ RISCVTargetInfo::checkCallingConvention(CallingConv CC) const {
     return CCCR_Warning;
   case CC_C:
   case CC_RISCVVectorCall:
+  case CC_RISCVVLSCall:
     return CCCR_OK;
   }
 }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 234a9c16e39df..e6e05ee92ac38 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -77,6 +77,7 @@ unsigned CodeGenTypes::ClangCallConvToLLVMCallConv(CallingConv CC) {
     // clang-format off
   case CC_RISCVVectorCall: return llvm::CallingConv::RISCV_VectorCall;
     // clang-format on
+  case CC_RISCVVLSCall: return llvm::CallingConv::RISCV_VLSCall;
   }
 }
 
@@ -266,6 +267,9 @@ static CallingConv getCallingConventionForDecl(const ObjCMethodDecl *D,
   if (D->hasAttr<RISCVVectorCCAttr>())
     return CC_RISCVVectorCall;
 
+  if (D->hasAttr<RISCVVLSCCAttr>())
+    return CC_RISCVVLSCall;
+
   return CC_C;
 }
 
@@ -862,6 +866,7 @@ CGFunctionInfo *CGFunctionInfo::create(unsigned llvmCC, bool instanceMethod,
   FI->HasExtParameterInfos = !paramInfos.empty();
   FI->getArgsBuffer()[0].type = resultType;
   FI->MaxVectorWidth = 0;
+  FI->Log2RISCVABIVLen = info.getLog2RISCVABIVLen();
   for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
     FI->getArgsBuffer()[i + 1].type = argTypes[i];
   for (unsigned i = 0, e = paramInfos.size(); i != e; ++i)
diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index 3d8a715b692de..d437688fb577c 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -1554,6 +1554,8 @@ static unsigned getDwarfCC(CallingConv CC) {
     return llvm::dwarf::DW_CC_LLVM_PreserveNone;
   case CC_RISCVVectorCall:
     return llvm::dwarf::DW_CC_LLVM_RISCVVectorCall;
+  case CC_RISCVVLSCall:
+    return llvm::dwarf::DW_CC_LLVM_RISCVVectorCall;
   }
   return 0;
 }
diff --git a/clang/lib/CodeGen/Targets/RISCV.cpp b/clang/lib/CodeGen/Targets/RISCV.cpp
index f2add9351c03c..4d16eaad781dc 100644
--- a/clang/lib/CodeGen/Targets/RISCV.cpp
+++ b/clang/lib/CodeGen/Targets/RISCV.cpp
@@ -8,6 +8,7 @@
 
 #include "ABIInfoImpl.h"
 #include "TargetInfo.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
 
 using namespace clang;
 using namespace clang::CodeGen;
@@ -45,8 +46,8 @@ class RISCVABIInfo : public DefaultABIInfo {
   void computeInfo(CGFunctionInfo &FI) const override;
 
   ABIArgInfo classifyArgumentType(QualType Ty, bool IsFixed, int &ArgGPRsLeft,
-                                  int &ArgFPRsLeft) const;
-  ABIArgInfo classifyReturnType(QualType RetTy) const;
+                                  int &ArgFPRsLeft, unsigned ABIVLen) const;
+  ABIArgInfo classifyReturnType(QualType RetTy, unsigned ABIVLen) const;
 
   RValue EmitVAArg(CodeGenFunction &CGF, Address VAListAddr, QualType Ty,
                    AggValueSlot Slot) const override;
@@ -62,14 +63,23 @@ class RISCVABIInfo : public DefaultABIInfo {
                                                llvm::Type *Field2Ty,
                                                CharUnits Field2Off) const;
 
-  ABIArgInfo coerceVLSVector(QualType Ty) const;
+  ABIArgInfo coerceVLSVector(QualType Ty, unsigned ABIVLen = 0) const;
 };
 } // end anonymous namespace
 
 void RISCVABIInfo::computeInfo(CGFunctionInfo &FI) const {
+  unsigned ABIVLen = 1 << FI.getExtInfo().getLog2RISCVABIVLen();
+  if (ABIVLen == 1)
+    // No riscv_vls_cc in the function, check if there's one passed from
+    // compiler options.
+    for (unsigned i = 5; i <= 16; ++i)
+      if (getContext().getTargetInfo().getTargetOpts().FeatureMap.contains(
+              "abi-vlen-" + llvm::utostr(1 << i) + "b"))
+        ABIVLen = 1 << i;
+
   QualType RetTy = FI.getReturnType();
   if (!getCXXABI().classifyReturnType(FI))
-    FI.getReturnInfo() = classifyReturnType(RetTy);
+    FI.getReturnInfo() = classifyReturnType(RetTy, ABIVLen);
 
   // IsRetIndirect is true if classifyArgumentType indicated the value should
   // be passed indirect, or if the type size is a scalar greater than 2*XLen
@@ -96,7 +106,7 @@ void RISCVABIInfo::computeInfo(CGFunctionInfo &FI) const {
   for (auto &ArgInfo : FI.arguments()) {
     bool IsFixed = ArgNum < NumFixedArgs;
     ArgInfo.info =
-        classifyArgumentType(ArgInfo.type, IsFixed, ArgGPRsLeft, ArgFPRsLeft);
+        classifyArgumentType(ArgInfo.type, IsFixed, ArgGPRsLeft, ArgFPRsLeft, ABIVLen);
     ArgNum++;
   }
 }
@@ -317,38 +327,44 @@ ABIArgInfo RISCVABIInfo::coerceAndExpandFPCCEligibleStruct(
 
 // Fixed-length RVV vectors are represented as scalable vectors in function
 // args/return and must be coerced from fixed vectors.
-ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty) const {
+ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty, unsigned ABIVLen) const {
   assert(Ty->isVectorType() && "expected vector type!");
 
   const auto *VT = Ty->castAs<VectorType>();
   assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
 
-  auto VScale =
-      getContext().getTargetInfo().getVScaleRange(getContext().getLangOpts());
-
   unsigned NumElts = VT->getNumElements();
-  llvm::Type *EltType;
-  if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
-    NumElts *= 8;
-    EltType = llvm::Type::getInt1Ty(getVMContext());
+  llvm::ScalableVectorType *ResType;
+  llvm::Type *EltType = CGT.ConvertType(VT->getElementType());;
+
+  if (ABIVLen == 0) {
+    // RVV fixed-length vector
+    auto VScale =
+        getContext().getTargetInfo().getVScaleRange(getContext().getLangOpts());
+
+    if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
+      NumElts *= 8;
+      EltType = llvm::Type::getInt1Ty(getVMContext());
+    }
+
+    // The MinNumElts is simplified from equation:
+    // NumElts / VScale =
+    //  (EltSize * NumElts / (VScale * RVVBitsPerBlock))
+    //    * (RVVBitsPerBlock / EltSize)
+    ResType = llvm::ScalableVectorType::get(EltType, NumElts / VScale->first);
   } else {
-    assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData &&
-           "Unexpected vector kind");
-    EltType = CGT.ConvertType(VT->getElementType());
+    // Generic vector
+    ResType = llvm::ScalableVectorType::get(
+        EltType, NumElts * llvm::RISCV::RVVBitsPerBlock / ABIVLen);
   }
 
-  // The MinNumElts is simplified from equation:
-  // NumElts / VScale =
-  //  (EltSize * NumElts / (VScale * RVVBitsPerBlock))
-  //    * (RVVBitsPerBlock / EltSize)
-  llvm::ScalableVectorType *ResType =
-      llvm::ScalableVectorType::get(EltType, NumElts / VScale->first);
   return ABIArgInfo::getDirect(ResType);
 }
 
 ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
                                               int &ArgGPRsLeft,
-                                              int &ArgFPRsLeft) const {
+                                              int &ArgFPRsLeft,
+                                              unsigned ABIVLen) const {
   assert(ArgGPRsLeft <= NumArgGPRs && "Arg GPR tracking underflow");
   Ty = useFirstFieldIfTransparentUnion(Ty);
 
@@ -451,10 +467,15 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
     return Info;
   }
 
-  if (const Vect...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/100346


More information about the cfe-commits mailing list