[llvm-branch-commits] [clang] 08dbbaf - [RISCV][NFC] Refactor RISC-V vector intrinsic utils.
Kito Cheng via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed May 11 02:57:51 PDT 2022
Author: Kito Cheng
Date: 2022-05-11T17:56:59+08:00
New Revision: 08dbbaf68d88a57e977d0674bddd0142e5d1e0b9
URL: https://github.com/llvm/llvm-project/commit/08dbbaf68d88a57e977d0674bddd0142e5d1e0b9
DIFF: https://github.com/llvm/llvm-project/commit/08dbbaf68d88a57e977d0674bddd0142e5d1e0b9.diff
LOG: [RISCV][NFC] Refactor RISC-V vector intrinsic utils.
This patch is preparation for D111617: it uses class/struct/enum rather than
char/StringRef to represent internal information where possible, which
provides a more compact way to store that information and is also easier to
serialize/deserialize.
It also improves the readability of the code, e.g. "v" vs
TypeProfile::Vector.
Differential Revision: https://reviews.llvm.org/D124730
Added:
Modified:
clang/include/clang/Support/RISCVVIntrinsicUtils.h
clang/lib/Support/RISCVVIntrinsicUtils.cpp
clang/utils/TableGen/RISCVVEmitter.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
index 1a4947d0c3df3..ddd46fe1727c9 100644
--- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -9,7 +9,10 @@
#ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
#define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <string>
@@ -18,9 +21,128 @@
namespace clang {
namespace RISCV {
-using BasicType = char;
using VScaleVal = llvm::Optional<unsigned>;
+// Modifier for vector type.
+enum class VectorTypeModifier : uint8_t {
+ NoModifier,
+ Log2EEW3,
+ Log2EEW4,
+ Log2EEW5,
+ Log2EEW6,
+ FixedSEW8,
+ FixedSEW16,
+ FixedSEW32,
+ FixedSEW64,
+ LFixedLog2LMULN3,
+ LFixedLog2LMULN2,
+ LFixedLog2LMULN1,
+ LFixedLog2LMUL0,
+ LFixedLog2LMUL1,
+ LFixedLog2LMUL2,
+ LFixedLog2LMUL3,
+ SFixedLog2LMULN3,
+ SFixedLog2LMULN2,
+ SFixedLog2LMULN1,
+ SFixedLog2LMUL0,
+ SFixedLog2LMUL1,
+ SFixedLog2LMUL2,
+ SFixedLog2LMUL3,
+};
+
+// Similar to basic type but used to describe what's kind of type related to
+// basic vector type, used to compute type info of arguments.
+enum class PrimitiveType : uint8_t {
+ Invalid,
+ Scalar,
+ Vector,
+ Widening2XVector,
+ Widening4XVector,
+ Widening8XVector,
+ MaskVector,
+ Void,
+ SizeT,
+ Ptrdiff,
+ UnsignedLong,
+ SignedLong,
+};
+
+// Modifier for type, used for both scalar and vector types.
+enum class TypeModifier : uint8_t {
+ NoModifier = 0,
+ Pointer = 1 << 0,
+ Const = 1 << 1,
+ Immediate = 1 << 2,
+ UnsignedInteger = 1 << 3,
+ SignedInteger = 1 << 4,
+ Float = 1 << 5,
+ LMUL1 = 1 << 6,
+ MaxOffset = 6,
+ LLVM_MARK_AS_BITMASK_ENUM(LMUL1),
+};
+
+// TypeProfile is used to compute type info of arguments or return value.
+struct TypeProfile {
+ constexpr TypeProfile() = default;
+ constexpr TypeProfile(PrimitiveType PT) : PT(static_cast<uint8_t>(PT)) {}
+ constexpr TypeProfile(PrimitiveType PT, TypeModifier TM)
+ : PT(static_cast<uint8_t>(PT)), TM(static_cast<uint8_t>(TM)) {}
+ constexpr TypeProfile(uint8_t PT, uint8_t VTM, uint8_t TM)
+ : PT(PT), VTM(VTM), TM(TM) {}
+
+ // Raw encodings of PrimitiveType, VectorTypeModifier and TypeModifier (the
+ // latter is a bitmask), stored as uint8_t to keep the profile compact.
+ uint8_t PT = static_cast<uint8_t>(PrimitiveType::Invalid);
+ uint8_t VTM = static_cast<uint8_t>(VectorTypeModifier::NoModifier);
+ uint8_t TM = static_cast<uint8_t>(TypeModifier::NoModifier);
+
+ // Returns "PT_VTM_TM" with decimal fields; the '_' separators keep the
+ // string unique per profile.
+ std::string IndexStr() const {
+ return std::to_string(PT) + "_" + std::to_string(VTM) + "_" +
+ std::to_string(TM);
+ }
+
+ bool operator!=(const TypeProfile &TP) const {
+ return TP.PT != PT || TP.VTM != VTM || TP.TM != TM;
+ }
+ // Strict lexicographic order on (PT, VTM, TM), so this is a valid ordering
+ // predicate for sorting and ordered containers.
+ bool operator>(const TypeProfile &TP) const {
+ if (PT != TP.PT)
+ return PT > TP.PT;
+ if (VTM != TP.VTM)
+ return VTM > TP.VTM;
+ return TM > TP.TM;
+ }
+
+ // Frequently used profiles: mask vector, plain vector and the size_t VL
+ // operand.
+ static const TypeProfile Mask;
+ static const TypeProfile Vector;
+ static const TypeProfile VL;
+ // Parse a single operand prototype string (e.g. "Uv") into a profile.
+ static llvm::Optional<TypeProfile>
+ parseTypeProfile(llvm::StringRef PrototypeStr);
+};
+
+llvm::SmallVector<TypeProfile> parsePrototypes(llvm::StringRef Prototypes);
+
+// Basic type of vector type.
+enum class BasicType : uint8_t {
+ Unknown = 0,
+ Int8 = 1 << 0,
+ Int16 = 1 << 1,
+ Int32 = 1 << 2,
+ Int64 = 1 << 3,
+ Float16 = 1 << 4,
+ Float32 = 1 << 5,
+ Float64 = 1 << 6,
+ MaxOffset = 6,
+ LLVM_MARK_AS_BITMASK_ENUM(Float64),
+};
+
+// Type of vector type.
+enum ScalarTypeKind : uint8_t {
+ Void,
+ Size_t,
+ Ptrdiff_t,
+ UnsignedLong,
+ SignedLong,
+ Boolean,
+ SignedInteger,
+ UnsignedInteger,
+ Float,
+ Invalid,
+};
+
// Exponential LMUL
struct LMULType {
int Log2LMUL;
@@ -32,20 +154,12 @@ struct LMULType {
LMULType &operator*=(uint32_t RHS);
};
+class RVVType;
+using RVVTypePtr = RVVType *;
+using RVVTypes = std::vector<RVVTypePtr>;
+
// This class is compact representation of a valid and invalid RVVType.
class RVVType {
- enum ScalarTypeKind : uint32_t {
- Void,
- Size_t,
- Ptrdiff_t,
- UnsignedLong,
- SignedLong,
- Boolean,
- SignedInteger,
- UnsignedInteger,
- Float,
- Invalid,
- };
BasicType BT;
ScalarTypeKind ScalarType = Invalid;
LMULType LMUL;
@@ -64,8 +178,8 @@ class RVVType {
std::string ShortStr;
public:
- RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {}
- RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype);
+ RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {}
+ RVVType(BasicType BT, int Log2LMUL, const TypeProfile &Profile);
// Return the string representation of a type, which is an encoded string for
// passing to the BUILTIN() macro in Builtins.def.
@@ -114,7 +228,11 @@ class RVVType {
// Applies a prototype modifier to the current type. The result maybe an
// invalid type.
- void applyModifier(llvm::StringRef prototype);
+ void applyModifier(const TypeProfile &prototype);
+
+ void applyLog2EEW(unsigned Log2EEW);
+ void applyFixedSEW(unsigned NewSEW);
+ void applyFixedLog2LMUL(int Log2LMUL, bool LargerThan);
// Compute and record a string for legal type.
void initBuiltinStr();
@@ -124,10 +242,19 @@ class RVVType {
void initTypeStr();
// Compute and record a short name of a type for C/C++ name suffix.
void initShortStr();
+
+public:
+ /// Compute output and input types by applying different config (basic type
+ /// and LMUL with type transformers). It also record result of type in legal
+ /// or illegal set to avoid compute the same config again. The result maybe
+ /// have illegal RVVType.
+ static llvm::Optional<RVVTypes>
+ computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
+ llvm::ArrayRef<TypeProfile> PrototypeSeq);
+ static llvm::Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL,
+ TypeProfile Proto);
};
-using RVVTypePtr = RVVType *;
-using RVVTypes = std::vector<RVVTypePtr>;
using RISCVPredefinedMacroT = uint8_t;
enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
@@ -206,6 +333,10 @@ class RVVIntrinsic {
// Return the type string for a BUILTIN() macro in Builtins.def.
std::string getBuiltinTypeStr() const;
+
+ static std::string
+ getSuffixStr(BasicType Type, int Log2LMUL,
+ const llvm::SmallVector<TypeProfile> &TypeProfiles);
};
} // end namespace RISCV
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
index 2e2f92d4804fe..0f21aa113eec3 100644
--- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -22,6 +22,14 @@ using namespace llvm;
namespace clang {
namespace RISCV {
+const TypeProfile TypeProfile::Mask = TypeProfile(PrimitiveType::MaskVector);
+const TypeProfile TypeProfile::VL = TypeProfile(PrimitiveType::SizeT);
+const TypeProfile TypeProfile::Vector = TypeProfile(PrimitiveType::Vector);
+
+// Concat BasicType, LMUL and Proto as key
+static StringMap<RVVType> LegalTypes;
+static StringSet<> IllegalTypes;
+
//===----------------------------------------------------------------------===//
// Type implementation
//===----------------------------------------------------------------------===//
@@ -70,7 +78,7 @@ LMULType &LMULType::operator*=(uint32_t RHS) {
return *this;
}
-RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
+RVVType::RVVType(BasicType BT, int Log2LMUL, const TypeProfile &prototype)
: BT(BT), LMUL(LMULType(Log2LMUL)) {
applyBasicType();
applyModifier(prototype);
@@ -326,31 +334,31 @@ void RVVType::initShortStr() {
void RVVType::applyBasicType() {
switch (BT) {
- case 'c':
+ case BasicType::Int8:
ElementBitwidth = 8;
ScalarType = ScalarTypeKind::SignedInteger;
break;
- case 's':
+ case BasicType::Int16:
ElementBitwidth = 16;
ScalarType = ScalarTypeKind::SignedInteger;
break;
- case 'i':
+ case BasicType::Int32:
ElementBitwidth = 32;
ScalarType = ScalarTypeKind::SignedInteger;
break;
- case 'l':
+ case BasicType::Int64:
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::SignedInteger;
break;
- case 'x':
+ case BasicType::Float16:
ElementBitwidth = 16;
ScalarType = ScalarTypeKind::Float;
break;
- case 'f':
+ case BasicType::Float32:
ElementBitwidth = 32;
ScalarType = ScalarTypeKind::Float;
break;
- case 'd':
+ case BasicType::Float64:
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::Float;
break;
@@ -360,162 +368,460 @@ void RVVType::applyBasicType() {
assert(ElementBitwidth != 0 && "Bad element bitwidth!");
}
-void RVVType::applyModifier(StringRef Transformer) {
- if (Transformer.empty())
- return;
+// Parse one operand prototype: the trailing character selects the primitive
+// type, an optional leading "(Kind:Value)" selects a vector type modifier, and
+// the remaining characters are accumulated into the TypeModifier bitmask.
+Optional<TypeProfile>
+TypeProfile::parseTypeProfile(llvm::StringRef TypeProfileStr) {
+ TypeProfile TP;
+ PrimitiveType PT = PrimitiveType::Invalid;
+ if (TypeProfileStr.empty())
+ return TP;
// Handle primitive type transformer
- auto PType = Transformer.back();
+ auto PType = TypeProfileStr.back();
switch (PType) {
case 'e':
- Scale = 0;
+ PT = PrimitiveType::Scalar;
break;
case 'v':
- Scale = LMUL.getScale(ElementBitwidth);
+ PT = PrimitiveType::Vector;
break;
case 'w':
- ElementBitwidth *= 2;
- LMUL *= 2;
- Scale = LMUL.getScale(ElementBitwidth);
+ PT = PrimitiveType::Widening2XVector;
break;
case 'q':
- ElementBitwidth *= 4;
- LMUL *= 4;
- Scale = LMUL.getScale(ElementBitwidth);
+ PT = PrimitiveType::Widening4XVector;
break;
case 'o':
- ElementBitwidth *= 8;
- LMUL *= 8;
- Scale = LMUL.getScale(ElementBitwidth);
+ PT = PrimitiveType::Widening8XVector;
break;
case 'm':
- ScalarType = ScalarTypeKind::Boolean;
- Scale = LMUL.getScale(ElementBitwidth);
- ElementBitwidth = 1;
+ PT = PrimitiveType::MaskVector;
break;
case '0':
- ScalarType = ScalarTypeKind::Void;
+ PT = PrimitiveType::Void;
break;
case 'z':
- ScalarType = ScalarTypeKind::Size_t;
+ PT = PrimitiveType::SizeT;
break;
case 't':
- ScalarType = ScalarTypeKind::Ptrdiff_t;
+ PT = PrimitiveType::Ptrdiff;
break;
case 'u':
- ScalarType = ScalarTypeKind::UnsignedLong;
+ PT = PrimitiveType::UnsignedLong;
break;
case 'l':
- ScalarType = ScalarTypeKind::SignedLong;
+ PT = PrimitiveType::SignedLong;
break;
default:
llvm_unreachable("Illegal primitive type transformers!");
}
- Transformer = Transformer.drop_back();
+ TP.PT = static_cast<uint8_t>(PT);
+ TypeProfileStr = TypeProfileStr.drop_back();
// Extract and compute complex type transformer. It can only appear one time.
- if (Transformer.startswith("(")) {
- size_t Idx = Transformer.find(')');
+ if (TypeProfileStr.startswith("(")) {
+ size_t Idx = TypeProfileStr.find(')');
assert(Idx != StringRef::npos);
- StringRef ComplexType = Transformer.slice(1, Idx);
- Transformer = Transformer.drop_front(Idx + 1);
- assert(!Transformer.contains('(') &&
+ StringRef ComplexType = TypeProfileStr.slice(1, Idx);
+ TypeProfileStr = TypeProfileStr.drop_front(Idx + 1);
+ assert(!TypeProfileStr.contains('(') &&
"Only allow one complex type transformer");
- auto UpdateAndCheckComplexProto = [&]() {
- Scale = LMUL.getScale(ElementBitwidth);
- const StringRef VectorPrototypes("vwqom");
- if (!VectorPrototypes.contains(PType))
- llvm_unreachable("Complex type transformer only supports vector type!");
- if (Transformer.find_first_of("PCKWS") != StringRef::npos)
- llvm_unreachable(
- "Illegal type transformer for Complex type transformer");
- };
- auto ComputeFixedLog2LMUL =
- [&](StringRef Value,
- std::function<bool(const int32_t &, const int32_t &)> Compare) {
- int32_t Log2LMUL;
- Value.getAsInteger(10, Log2LMUL);
- if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
- ScalarType = Invalid;
- return false;
- }
- // Update new LMUL
- LMUL = LMULType(Log2LMUL);
- UpdateAndCheckComplexProto();
- return true;
- };
+ // Complex type transformers only apply to vector primitive types and cannot
+ // be combined with the P/C/K/W/S transformers.
+ if (!StringRef("vwqom").contains(PType))
+ llvm_unreachable("Complex type transformer only supports vector type!");
+ if (TypeProfileStr.find_first_of("PCKWS") != StringRef::npos)
+ llvm_unreachable(
+ "Illegal type transformer for Complex type transformer");
auto ComplexTT = ComplexType.split(":");
+ VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
if (ComplexTT.first == "Log2EEW") {
uint32_t Log2EEW;
- ComplexTT.second.getAsInteger(10, Log2EEW);
- // update new elmul = (eew/sew) * lmul
- LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
- // update new eew
- ElementBitwidth = 1 << Log2EEW;
- ScalarType = ScalarTypeKind::SignedInteger;
- UpdateAndCheckComplexProto();
+ if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
+ llvm_unreachable("Invalid Log2EEW value!");
+ return None;
+ }
+ switch (Log2EEW) {
+ case 3:
+ VTM = VectorTypeModifier::Log2EEW3;
+ break;
+ case 4:
+ VTM = VectorTypeModifier::Log2EEW4;
+ break;
+ case 5:
+ VTM = VectorTypeModifier::Log2EEW5;
+ break;
+ case 6:
+ VTM = VectorTypeModifier::Log2EEW6;
+ break;
+ default:
+ llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
+ return None;
+ }
} else if (ComplexTT.first == "FixedSEW") {
uint32_t NewSEW;
- ComplexTT.second.getAsInteger(10, NewSEW);
- // Set invalid type if src and dst SEW are same.
- if (ElementBitwidth == NewSEW) {
- ScalarType = Invalid;
- return;
+ if (ComplexTT.second.getAsInteger(10, NewSEW)) {
+ llvm_unreachable("Invalid FixedSEW value!");
+ return None;
+ }
+ switch (NewSEW) {
+ case 8:
+ VTM = VectorTypeModifier::FixedSEW8;
+ break;
+ case 16:
+ VTM = VectorTypeModifier::FixedSEW16;
+ break;
+ case 32:
+ VTM = VectorTypeModifier::FixedSEW32;
+ break;
+ case 64:
+ VTM = VectorTypeModifier::FixedSEW64;
+ break;
+ default:
+ llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
+ return None;
}
- // Update new SEW
- ElementBitwidth = NewSEW;
- UpdateAndCheckComplexProto();
} else if (ComplexTT.first == "LFixedLog2LMUL") {
- // New LMUL should be larger than old
- if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
- return;
+ int32_t Log2LMUL;
+ if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
+ llvm_unreachable("Invalid LFixedLog2LMUL value!");
+ return None;
+ }
+ switch (Log2LMUL) {
+ case -3:
+ VTM = VectorTypeModifier::LFixedLog2LMULN3;
+ break;
+ case -2:
+ VTM = VectorTypeModifier::LFixedLog2LMULN2;
+ break;
+ case -1:
+ VTM = VectorTypeModifier::LFixedLog2LMULN1;
+ break;
+ case 0:
+ VTM = VectorTypeModifier::LFixedLog2LMUL0;
+ break;
+ case 1:
+ VTM = VectorTypeModifier::LFixedLog2LMUL1;
+ break;
+ case 2:
+ VTM = VectorTypeModifier::LFixedLog2LMUL2;
+ break;
+ case 3:
+ VTM = VectorTypeModifier::LFixedLog2LMUL3;
+ break;
+ default:
+ llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
+ return None;
+ }
} else if (ComplexTT.first == "SFixedLog2LMUL") {
- // New LMUL should be smaller than old
- if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
- return;
+ int32_t Log2LMUL;
+ if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
+ llvm_unreachable("Invalid SFixedLog2LMUL value!");
+ return None;
+ }
+ switch (Log2LMUL) {
+ case -3:
+ VTM = VectorTypeModifier::SFixedLog2LMULN3;
+ break;
+ case -2:
+ VTM = VectorTypeModifier::SFixedLog2LMULN2;
+ break;
+ case -1:
+ VTM = VectorTypeModifier::SFixedLog2LMULN1;
+ break;
+ case 0:
+ VTM = VectorTypeModifier::SFixedLog2LMUL0;
+ break;
+ case 1:
+ VTM = VectorTypeModifier::SFixedLog2LMUL1;
+ break;
+ case 2:
+ VTM = VectorTypeModifier::SFixedLog2LMUL2;
+ break;
+ case 3:
+ VTM = VectorTypeModifier::SFixedLog2LMUL3;
+ break;
+ default:
+ llvm_unreachable("Invalid SFixedLog2LMUL value, should be [-3, 3]");
+ return None;
+ }
+
} else {
llvm_unreachable("Illegal complex type transformers!");
}
+ TP.VTM = static_cast<uint8_t>(VTM);
}
// Compute the remain type transformers
- for (char I : Transformer) {
+ TypeModifier TM = TypeModifier::NoModifier;
+ for (char I : TypeProfileStr) {
switch (I) {
case 'P':
- if (IsConstant)
+ if ((TM & TypeModifier::Const) == TypeModifier::Const)
llvm_unreachable("'P' transformer cannot be used after 'C'");
- if (IsPointer)
+ if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
llvm_unreachable("'P' transformer cannot be used twice");
- IsPointer = true;
+ TM |= TypeModifier::Pointer;
break;
case 'C':
- if (IsConstant)
- llvm_unreachable("'C' transformer cannot be used twice");
- IsConstant = true;
+ if ((TM & TypeModifier::Const) == TypeModifier::Const)
+ llvm_unreachable("'C' transformer cannot be used twice");
+ TM |= TypeModifier::Const;
break;
case 'K':
- IsImmediate = true;
+ TM |= TypeModifier::Immediate;
break;
case 'U':
- ScalarType = ScalarTypeKind::UnsignedInteger;
+ TM |= TypeModifier::UnsignedInteger;
break;
case 'I':
- ScalarType = ScalarTypeKind::SignedInteger;
+ TM |= TypeModifier::SignedInteger;
break;
case 'F':
- ScalarType = ScalarTypeKind::Float;
+ TM |= TypeModifier::Float;
break;
case 'S':
+ TM |= TypeModifier::LMUL1;
+ break;
+ default:
+ llvm_unreachable("Illegal non-primitive type transformer!");
+ }
+ }
+ TP.TM = static_cast<uint8_t>(TM);
+
+ return TP;
+}
+
+// Apply a parsed TypeProfile to this type: the primitive type first, then the
+// optional vector type modifier, then each set bit of the TypeModifier mask
+// from bit 0 upwards. The result may be an invalid type.
+void RVVType::applyModifier(const TypeProfile &Transformer) {
+ // Handle primitive type transformer
+ switch (static_cast<PrimitiveType>(Transformer.PT)) {
+ case PrimitiveType::Scalar:
+ Scale = 0;
+ break;
+ case PrimitiveType::Vector:
+ Scale = LMUL.getScale(ElementBitwidth);
+ break;
+ case PrimitiveType::Widening2XVector:
+ ElementBitwidth *= 2;
+ LMUL *= 2;
+ Scale = LMUL.getScale(ElementBitwidth);
+ break;
+ case PrimitiveType::Widening4XVector:
+ ElementBitwidth *= 4;
+ LMUL *= 4;
+ Scale = LMUL.getScale(ElementBitwidth);
+ break;
+ case PrimitiveType::Widening8XVector:
+ ElementBitwidth *= 8;
+ LMUL *= 8;
+ Scale = LMUL.getScale(ElementBitwidth);
+ break;
+ case PrimitiveType::MaskVector:
+ ScalarType = ScalarTypeKind::Boolean;
+ Scale = LMUL.getScale(ElementBitwidth);
+ ElementBitwidth = 1;
+ break;
+ case PrimitiveType::Void:
+ ScalarType = ScalarTypeKind::Void;
+ break;
+ case PrimitiveType::SizeT:
+ ScalarType = ScalarTypeKind::Size_t;
+ break;
+ case PrimitiveType::Ptrdiff:
+ ScalarType = ScalarTypeKind::Ptrdiff_t;
+ break;
+ case PrimitiveType::UnsignedLong:
+ ScalarType = ScalarTypeKind::UnsignedLong;
+ break;
+ case PrimitiveType::SignedLong:
+ ScalarType = ScalarTypeKind::SignedLong;
+ break;
+ case PrimitiveType::Invalid:
+ ScalarType = ScalarTypeKind::Invalid;
+ return;
+ default:
+ llvm_unreachable("Illegal primitive type transformers!");
+ }
+
+ switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
+ case VectorTypeModifier::Log2EEW3:
+ applyLog2EEW(3);
+ break;
+ case VectorTypeModifier::Log2EEW4:
+ applyLog2EEW(4);
+ break;
+ case VectorTypeModifier::Log2EEW5:
+ applyLog2EEW(5);
+ break;
+ case VectorTypeModifier::Log2EEW6:
+ applyLog2EEW(6);
+ break;
+ case VectorTypeModifier::FixedSEW8:
+ applyFixedSEW(8);
+ break;
+ case VectorTypeModifier::FixedSEW16:
+ applyFixedSEW(16);
+ break;
+ case VectorTypeModifier::FixedSEW32:
+ applyFixedSEW(32);
+ break;
+ case VectorTypeModifier::FixedSEW64:
+ applyFixedSEW(64);
+ break;
+ case VectorTypeModifier::LFixedLog2LMULN3:
+ applyFixedLog2LMUL(-3, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::LFixedLog2LMULN2:
+ applyFixedLog2LMUL(-2, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::LFixedLog2LMULN1:
+ applyFixedLog2LMUL(-1, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::LFixedLog2LMUL0:
+ applyFixedLog2LMUL(0, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::LFixedLog2LMUL1:
+ applyFixedLog2LMUL(1, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::LFixedLog2LMUL2:
+ applyFixedLog2LMUL(2, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::LFixedLog2LMUL3:
+ applyFixedLog2LMUL(3, /* LargerThan= */ true);
+ break;
+ case VectorTypeModifier::SFixedLog2LMULN3:
+ applyFixedLog2LMUL(-3, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::SFixedLog2LMULN2:
+ applyFixedLog2LMUL(-2, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::SFixedLog2LMULN1:
+ applyFixedLog2LMUL(-1, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::SFixedLog2LMUL0:
+ applyFixedLog2LMUL(0, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::SFixedLog2LMUL1:
+ applyFixedLog2LMUL(1, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::SFixedLog2LMUL2:
+ applyFixedLog2LMUL(2, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::SFixedLog2LMUL3:
+ applyFixedLog2LMUL(3, /* LargerThan= */ false);
+ break;
+ case VectorTypeModifier::NoModifier:
+ break;
+ default:
+ llvm_unreachable("Illegal vector type modifier!");
+ }
+
+ // Early return if a vector type modifier above marked the type invalid;
+ // otherwise a later U/I/F modifier would overwrite ScalarType and turn the
+ // invalid type back into a valid one.
+ if (ScalarType == ScalarTypeKind::Invalid)
+ return;
+
+ for (unsigned TypeModifierMaskShift = 0;
+ TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
+ ++TypeModifierMaskShift) {
+ unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
+ if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
+ TypeModifierMask)
+ continue;
+ switch (static_cast<TypeModifier>(TypeModifierMask)) {
+ case TypeModifier::Pointer:
+ IsPointer = true;
+ break;
+ case TypeModifier::Const:
+ IsConstant = true;
+ break;
+ case TypeModifier::Immediate:
+ // NOTE(review): the old 'K' handling set only IsImmediate; also setting
+ // IsConstant may change the emitted type strings -- confirm intended.
+ IsImmediate = true;
+ IsConstant = true;
+ break;
+ case TypeModifier::UnsignedInteger:
+ ScalarType = ScalarTypeKind::UnsignedInteger;
+ break;
+ case TypeModifier::SignedInteger:
+ ScalarType = ScalarTypeKind::SignedInteger;
+ break;
+ case TypeModifier::Float:
+ ScalarType = ScalarTypeKind::Float;
+ break;
+ case TypeModifier::LMUL1:
LMUL = LMULType(0);
// Update ElementBitwidth need to update Scale too.
Scale = LMUL.getScale(ElementBitwidth);
break;
default:
- llvm_unreachable("Illegal non-primitive type transformer!");
+ llvm_unreachable("Unknown type modifier mask!");
}
}
}
+void RVVType::applyLog2EEW(unsigned Log2EEW) {
+ // update new elmul = (eew/sew) * lmul
+ LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
+ // update new eew
+ ElementBitwidth = 1 << Log2EEW;
+ ScalarType = ScalarTypeKind::SignedInteger;
+ Scale = LMUL.getScale(ElementBitwidth);
+}
+
+void RVVType::applyFixedSEW(unsigned NewSEW) {
+ // Set invalid type if src and dst SEW are same.
+ if (ElementBitwidth == NewSEW) {
+ ScalarType = ScalarTypeKind::Invalid;
+ return;
+ }
+ // Update new SEW
+ ElementBitwidth = NewSEW;
+ Scale = LMUL.getScale(ElementBitwidth);
+}
+
+// Set LMUL to 2^Log2LMUL. With LargerThan (LFixedLog2LMUL) the new LMUL must be
+// strictly larger than the current one, otherwise (SFixedLog2LMUL) strictly
+// smaller; if not, the type is marked invalid and LMUL is left unchanged.
+void RVVType::applyFixedLog2LMUL(int Log2LMUL, bool LargerThan) {
+ if (LargerThan) {
+ if (Log2LMUL <= LMUL.Log2LMUL) {
+ ScalarType = ScalarTypeKind::Invalid;
+ return;
+ }
+ } else {
+ if (Log2LMUL >= LMUL.Log2LMUL) {
+ ScalarType = ScalarTypeKind::Invalid;
+ return;
+ }
+ }
+ // Update new LMUL
+ LMUL = LMULType(Log2LMUL);
+ Scale = LMUL.getScale(ElementBitwidth);
+}
+
+Optional<RVVTypes> RVVType::computeTypes(BasicType BT, int Log2LMUL,
+ unsigned NF,
+ ArrayRef<TypeProfile> PrototypeSeq) {
+ // LMUL x NF must be less than or equal to 8.
+ if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
+ return llvm::None;
+
+ RVVTypes Types;
+ for (const TypeProfile &Proto : PrototypeSeq) {
+ auto T = computeType(BT, Log2LMUL, Proto);
+ if (!T.hasValue())
+ return llvm::None;
+ // Record legal type index
+ Types.push_back(T.getValue());
+ }
+ return Types;
+}
+
+Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
+ TypeProfile Proto) {
+ std::string Idx =
+ Twine(Twine(static_cast<int>(BT)) + Twine(Log2LMUL) + Proto.IndexStr())
+ .str();
+ // Search first
+ auto It = LegalTypes.find(Idx);
+ if (It != LegalTypes.end())
+ return &(It->second);
+ if (IllegalTypes.count(Idx))
+ return llvm::None;
+ // Compute type and record the result.
+ RVVType T(BT, Log2LMUL, Proto);
+ if (T.isValid()) {
+ // Record legal type index and value.
+ LegalTypes.insert({Idx, T});
+ return &(LegalTypes[Idx]);
+ }
+ // Record illegal type index.
+ IllegalTypes.insert(Idx);
+ return llvm::None;
+}
+
//===----------------------------------------------------------------------===//
// RVVIntrinsic implementation
//===----------------------------------------------------------------------===//
@@ -593,5 +899,36 @@ std::string RVVIntrinsic::getBuiltinTypeStr() const {
return S;
}
+std::string
+RVVIntrinsic::getSuffixStr(BasicType Type, int Log2LMUL,
+ const llvm::SmallVector<TypeProfile> &TypeProfiles) {
+ SmallVector<std::string> SuffixStrs;
+ for (auto TP : TypeProfiles) {
+ auto T = RVVType::computeType(Type, Log2LMUL, TP);
+ SuffixStrs.push_back(T.getValue()->getShortStr());
+ }
+ return join(SuffixStrs, "_");
+}
+
+// Split a prototype string into one TypeProfile per operand. Each operand ends
+// at its primitive type character; a leading "(...)" complex modifier is
+// skipped first because it may itself contain primitive type characters.
+SmallVector<TypeProfile> parsePrototypes(StringRef Prototypes) {
+ SmallVector<TypeProfile> TypeProfiles;
+ const StringRef Primaries("evwqom0ztul");
+ while (!Prototypes.empty()) {
+ size_t Idx = 0;
+ // Skip over complex prototype because it could contain primitive type
+ // character.
+ if (Prototypes[0] == '(')
+ Idx = Prototypes.find_first_of(')');
+ Idx = Prototypes.find_first_of(Primaries, Idx);
+ assert(Idx != StringRef::npos);
+ auto TP = TypeProfile::parseTypeProfile(Prototypes.slice(0, Idx + 1));
+ if (!TP)
+ llvm_unreachable("Error during parsing prototype.");
+ TypeProfiles.push_back(*TP);
+ Prototypes = Prototypes.drop_front(Idx + 1);
+ }
+ // Return by value: std::move here would only inhibit copy elision.
+ return TypeProfiles;
+}
+
} // end namespace RISCV
} // end namespace clang
diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp
index bd9e74f2f0cf7..e8083d99f4742 100644
--- a/clang/utils/TableGen/RISCVVEmitter.cpp
+++ b/clang/utils/TableGen/RISCVVEmitter.cpp
@@ -32,9 +32,6 @@ namespace {
class RVVEmitter {
private:
RecordKeeper &Records;
- // Concat BasicType, LMUL and Proto as key
- StringMap<RVVType> LegalTypes;
- StringSet<> IllegalTypes;
public:
RVVEmitter(RecordKeeper &R) : Records(R) {}
@@ -48,20 +45,11 @@ class RVVEmitter {
/// Emit all the information needed to map builtin -> LLVM IR intrinsic.
void createCodeGen(raw_ostream &o);
- std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
-
private:
/// Create all intrinsics and add them to \p Out
void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
/// Print HeaderCode in RVVHeader Record to \p Out
void printHeaderCode(raw_ostream &OS);
- /// Compute output and input types by applying different config (basic type
- /// and LMUL with type transformers). It also record result of type in legal
- /// or illegal set to avoid compute the same config again. The result maybe
- /// have illegal RVVType.
- Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
- ArrayRef<std::string> PrototypeSeq);
- Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
/// Emit Acrh predecessor definitions and body, assume the element of Defs are
/// sorted by extension.
@@ -73,14 +61,39 @@ class RVVEmitter {
// non-empty string.
bool emitMacroRestrictionStr(RISCVPredefinedMacroT PredefinedMacros,
raw_ostream &o);
- // Slice Prototypes string into sub prototype string and process each sub
- // prototype string individually in the Handler.
- void parsePrototypes(StringRef Prototypes,
- std::function<void(StringRef)> Handler);
};
} // namespace
+static BasicType ParseBasicType(char c) {
+ switch (c) {
+ case 'c':
+ return BasicType::Int8;
+ break;
+ case 's':
+ return BasicType::Int16;
+ break;
+ case 'i':
+ return BasicType::Int32;
+ break;
+ case 'l':
+ return BasicType::Int64;
+ break;
+ case 'x':
+ return BasicType::Float16;
+ break;
+ case 'f':
+ return BasicType::Float32;
+ break;
+ case 'd':
+ return BasicType::Float64;
+ break;
+
+ default:
+ return BasicType::Unknown;
+ }
+}
+
void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) {
if (!RVVI->getIRName().empty())
OS << " ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n";
@@ -202,24 +215,28 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
// Print RVV boolean types.
for (int Log2LMUL : Log2LMULs) {
- auto T = computeType('c', Log2LMUL, "m");
+ auto T = RVVType::computeType(BasicType::Int8, Log2LMUL, TypeProfile::Mask);
if (T.hasValue())
printType(T.getValue());
}
// Print RVV int/float types.
for (char I : StringRef("csil")) {
+ BasicType BT = ParseBasicType(I);
for (int Log2LMUL : Log2LMULs) {
- auto T = computeType(I, Log2LMUL, "v");
+ auto T = RVVType::computeType(BT, Log2LMUL, TypeProfile::Vector);
if (T.hasValue()) {
printType(T.getValue());
- auto UT = computeType(I, Log2LMUL, "Uv");
+ auto UT = RVVType::computeType(
+ BT, Log2LMUL,
+ TypeProfile(PrimitiveType::Vector, TypeModifier::UnsignedInteger));
printType(UT.getValue());
}
}
}
OS << "#if defined(__riscv_zvfh)\n";
for (int Log2LMUL : Log2LMULs) {
- auto T = computeType('x', Log2LMUL, "v");
+ auto T =
+ RVVType::computeType(BasicType::Float16, Log2LMUL, TypeProfile::Vector);
if (T.hasValue())
printType(T.getValue());
}
@@ -227,7 +244,8 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
OS << "#if defined(__riscv_f)\n";
for (int Log2LMUL : Log2LMULs) {
- auto T = computeType('f', Log2LMUL, "v");
+ auto T =
+ RVVType::computeType(BasicType::Float32, Log2LMUL, TypeProfile::Vector);
if (T.hasValue())
printType(T.getValue());
}
@@ -235,7 +253,8 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
OS << "#if defined(__riscv_d)\n";
for (int Log2LMUL : Log2LMULs) {
- auto T = computeType('d', Log2LMUL, "v");
+ auto T =
+ RVVType::computeType(BasicType::Float64, Log2LMUL, TypeProfile::Vector);
if (T.hasValue())
printType(T.getValue());
}
@@ -359,32 +378,6 @@ void RVVEmitter::createCodeGen(raw_ostream &OS) {
OS << "\n";
}
-void RVVEmitter::parsePrototypes(StringRef Prototypes,
- std::function<void(StringRef)> Handler) {
- const StringRef Primaries("evwqom0ztul");
- while (!Prototypes.empty()) {
- size_t Idx = 0;
- // Skip over complex prototype because it could contain primitive type
- // character.
- if (Prototypes[0] == '(')
- Idx = Prototypes.find_first_of(')');
- Idx = Prototypes.find_first_of(Primaries, Idx);
- assert(Idx != StringRef::npos);
- Handler(Prototypes.slice(0, Idx + 1));
- Prototypes = Prototypes.drop_front(Idx + 1);
- }
-}
-
-std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
- StringRef Prototypes) {
- SmallVector<std::string> SuffixStrs;
- parsePrototypes(Prototypes, [&](StringRef Proto) {
- auto T = computeType(Type, Log2LMUL, Proto);
- SuffixStrs.push_back(T.getValue()->getShortStr());
- });
- return join(SuffixStrs, "_");
-}
-
void RVVEmitter::createRVVIntrinsics(
std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
@@ -419,13 +412,14 @@ void RVVEmitter::createRVVIntrinsics(
// Parse prototype and create a list of primitive type with transformers
// (operand) in ProtoSeq. ProtoSeq[0] is output operand.
- SmallVector<std::string> ProtoSeq;
- parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
- ProtoSeq.push_back(Proto.str());
- });
+ SmallVector<TypeProfile> ProtoSeq = parsePrototypes(Prototypes);
+
+ SmallVector<TypeProfile> SuffixProtoSeq = parsePrototypes(SuffixProto);
+ SmallVector<TypeProfile> MangledSuffixProtoSeq =
+ parsePrototypes(MangledSuffixProto);
// Compute Builtin types
- SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
+ SmallVector<TypeProfile> ProtoMaskSeq = ProtoSeq;
if (HasMasked) {
// If HasMaskedOffOperand, insert result type as first input operand.
if (HasMaskedOffOperand) {
@@ -436,10 +430,10 @@ void RVVEmitter::createRVVIntrinsics(
// (void, op0 address, op1 address, ...)
// to
// (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
+ TypeProfile MaskoffType = ProtoSeq[1];
+ MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
for (unsigned I = 0; I < NF; ++I)
- ProtoMaskSeq.insert(
- ProtoMaskSeq.begin() + NF + 1,
- ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
+ ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType);
}
}
if (HasMaskedOffOperand && NF > 1) {
@@ -448,28 +442,32 @@ void RVVEmitter::createRVVIntrinsics(
// to
// (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
// ...)
- ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
+ ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, TypeProfile::Mask);
} else {
- // If HasMasked, insert 'm' as first input operand.
- ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
+ // If HasMasked, insert TypeProfile:Mask as first input operand.
+ ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, TypeProfile::Mask);
}
}
- // If HasVL, append 'z' to last operand
+ // If HasVL, append TypeProfile:VL to last operand
if (HasVL) {
- ProtoSeq.push_back("z");
- ProtoMaskSeq.push_back("z");
+ ProtoSeq.push_back(TypeProfile::VL);
+ ProtoMaskSeq.push_back(TypeProfile::VL);
}
// Create Intrinsics for each type and LMUL.
for (char I : TypeRange) {
for (int Log2LMUL : Log2LMULList) {
- Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
+ BasicType BT = ParseBasicType(I);
+ Optional<RVVTypes> Types =
+ RVVType::computeTypes(BT, Log2LMUL, NF, ProtoSeq);
// Ignored to create new intrinsic if there are any illegal types.
if (!Types.hasValue())
continue;
- auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
- auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
+ auto SuffixStr =
+ RVVIntrinsic::getSuffixStr(BT, Log2LMUL, SuffixProtoSeq);
+ auto MangledSuffixStr =
+ RVVIntrinsic::getSuffixStr(BT, Log2LMUL, MangledSuffixProtoSeq);
// Create a unmasked intrinsic
Out.push_back(std::make_unique<RVVIntrinsic>(
Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
@@ -480,7 +478,7 @@ void RVVEmitter::createRVVIntrinsics(
if (HasMasked) {
// Create a masked intrinsic
Optional<RVVTypes> MaskTypes =
- computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
+ RVVType::computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq);
Out.push_back(std::make_unique<RVVIntrinsic>(
Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName,
/*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy,
@@ -501,45 +499,6 @@ void RVVEmitter::printHeaderCode(raw_ostream &OS) {
}
}
-Optional<RVVTypes>
-RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
- ArrayRef<std::string> PrototypeSeq) {
- // LMUL x NF must be less than or equal to 8.
- if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
- return llvm::None;
-
- RVVTypes Types;
- for (const std::string &Proto : PrototypeSeq) {
- auto T = computeType(BT, Log2LMUL, Proto);
- if (!T.hasValue())
- return llvm::None;
- // Record legal type index
- Types.push_back(T.getValue());
- }
- return Types;
-}
-
-Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
- StringRef Proto) {
- std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
- // Search first
- auto It = LegalTypes.find(Idx);
- if (It != LegalTypes.end())
- return &(It->second);
- if (IllegalTypes.count(Idx))
- return llvm::None;
- // Compute type and record the result.
- RVVType T(BT, Log2LMUL, Proto);
- if (T.isValid()) {
- // Record legal type index and value.
- LegalTypes.insert({Idx, T});
- return &(LegalTypes[Idx]);
- }
- // Record illegal type index.
- IllegalTypes.insert(Idx);
- return llvm::None;
-}
-
void RVVEmitter::emitArchMacroAndBody(
std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
More information about the llvm-branch-commits
mailing list