[clang] f26c41e - [RISCV] Moving RVV intrinsic type related util to clang/Support
Kito Cheng via cfe-commits
cfe-commits at lists.llvm.org
Wed Apr 20 06:13:18 PDT 2022
Author: Kito Cheng
Date: 2022-04-20T21:13:13+08:00
New Revision: f26c41e8dd28d86030cd0f5a6e9c11036acea5d2
URL: https://github.com/llvm/llvm-project/commit/f26c41e8dd28d86030cd0f5a6e9c11036acea5d2
DIFF: https://github.com/llvm/llvm-project/commit/f26c41e8dd28d86030cd0f5a6e9c11036acea5d2.diff
LOG: [RISCV] Moving RVV intrinsic type related util to clang/Support
We add a new clang library called `clangSupport` to hold utilities that are used by both clang-tblgen and other clang components.
We tried to put these into `llvm/Support`, but they are only used by clang* and clang-tblgen, so creating `clang/Support` seemed the better choice.
* clang will use these utilities in https://reviews.llvm.org/D111617.
Reviewed By: khchen, MaskRay, aaron.ballman
Differential Revision: https://reviews.llvm.org/D121984
Added:
clang/include/clang/Support/RISCVVIntrinsicUtils.h
clang/lib/Support/CMakeLists.txt
clang/lib/Support/RISCVVIntrinsicUtils.cpp
Modified:
clang/lib/CMakeLists.txt
clang/utils/TableGen/CMakeLists.txt
clang/utils/TableGen/RISCVVEmitter.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
new file mode 100644
index 0000000000000..1a4947d0c3df3
--- /dev/null
+++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -0,0 +1,215 @@
+//===--- RISCVVIntrinsicUtils.h - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
+#define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace clang {
+namespace RISCV {
+
+/// Single-character code naming the base element type ('c', 's', 'i', 'l',
+/// 'x', 'f' or 'd').
+using BasicType = char;
+/// Vector scale: 0 marks a scalar, a non-zero value is the N of nxvN<elt>, and
+/// an empty value marks an illegal SEW/LMUL combination.
+using VScaleVal = llvm::Optional<unsigned>;
+
+/// An RVV LMUL value stored as its base-2 exponent.
+struct LMULType {
+  /// log2(LMUL): -3 means LMUL=1/8, 3 means LMUL=8.
+  int Log2LMUL;
+
+  LMULType(int Log2LMUL);
+
+  /// Multiply LMUL by RHS, which must be a power of two.
+  LMULType &operator*=(uint32_t RHS);
+  /// Add Log2LMUL to the stored exponent (multiply LMUL by 2^Log2LMUL).
+  void MulLog2LMUL(int Log2LMUL);
+
+  /// Vector scale for the given element width, or None when it would be < 1.
+  llvm::Optional<unsigned> getScale(unsigned ElementBitwidth) const;
+  /// LMUL spelled as in type names, e.g. "m1", "m4" or "mf2".
+  std::string str() const;
+};
+
+/// Compact representation of an RVV type (scalar, vector or mask) derived from
+/// a basic element type, an LMUL and a prototype modifier string. Invalid
+/// combinations are representable; check isValid() before using the strings.
+class RVVType {
+  enum ScalarTypeKind : uint32_t {
+    Void,
+    Size_t,
+    Ptrdiff_t,
+    UnsignedLong,
+    SignedLong,
+    Boolean,
+    SignedInteger,
+    UnsignedInteger,
+    Float,
+    Invalid,
+  };
+  BasicType BT;
+  ScalarTypeKind ScalarType = Invalid;
+  LMULType LMUL;
+  bool IsPointer = false;
+  // Set by the 'K' modifier: the operand must be an integer constant
+  // expression, which is encoded with an "I" prefix in the builtin string.
+  bool IsImmediate = false;
+  // Const qualifier for pointer to const object or object of const type.
+  bool IsConstant = false;
+  unsigned ElementBitwidth = 0;
+  // 0 for scalars, the N of nxvN<elt> for vectors, and empty for illegal
+  // SEW/LMUL combinations.
+  VScaleVal Scale = 0;
+  bool Valid;
+
+  std::string BuiltinStr;
+  std::string ClangBuiltinStr;
+  std::string Str;
+  std::string ShortStr;
+
+public:
+  // NOTE(review): BasicType() is '\0', which applyBasicType() does not handle
+  // (it reaches llvm_unreachable), so this constructor cannot complete as
+  // written.
+  RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {}
+  RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype);
+
+  // Return the string representation of a type, which is an encoded string for
+  // passing to the BUILTIN() macro in Builtins.def.
+  const std::string &getBuiltinStr() const { return BuiltinStr; }
+
+  // Return the clang builtin type for RVV vector type which are used in the
+  // riscv_vector.h header file.
+  const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
+
+  // Return the C/C++ string representation of a type for use in the
+  // riscv_vector.h header file.
+  const std::string &getTypeStr() const { return Str; }
+
+  // Return the short name of a type for C/C++ name suffix.
+  const std::string &getShortStr() {
+    // Not every type needs a short name, so it is computed lazily on first
+    // use.
+    if (ShortStr.empty())
+      initShortStr();
+    return ShortStr;
+  }
+
+  bool isValid() const { return Valid; }
+  bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
+  bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
+  bool isVector(unsigned Width) const {
+    return isVector() && ElementBitwidth == Width;
+  }
+  bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
+  bool isSignedInteger() const {
+    return ScalarType == ScalarTypeKind::SignedInteger;
+  }
+  bool isFloatVector(unsigned Width) const {
+    return isVector() && isFloat() && ElementBitwidth == Width;
+  }
+  bool isFloat(unsigned Width) const {
+    return isFloat() && ElementBitwidth == Width;
+  }
+
+private:
+  // Return whether the computed type is a legal RVV type; the constructor
+  // stores the result in Valid.
+  bool verifyType() const;
+
+  // Set ScalarType and ElementBitwidth from the basic type code BT.
+  void applyBasicType();
+
+  // Applies a prototype modifier to the current type. The result may be an
+  // invalid type.
+  void applyModifier(llvm::StringRef prototype);
+
+  // Compute and record the Builtins.def encoding of a legal type.
+  void initBuiltinStr();
+  // Compute and record the clang builtin (__rvv_*) name of a vector type.
+  void initClangBuiltinStr();
+  // Compute and record the type spelling used in the header.
+  void initTypeStr();
+  // Compute and record a short name of a type for C/C++ name suffix.
+  void initShortStr();
+};
+
+/// Raw pointer to an RVVType (ownership is not expressed by this alias).
+using RVVTypePtr = RVVType *;
+/// Type list; RVVIntrinsic's constructor treats element 0 as the output type.
+using RVVTypes = std::vector<RVVTypePtr>;
+/// Bit set of RISCVPredefinedMacro flags.
+using RISCVPredefinedMacroT = uint8_t;
+
+/// Feature requirements of an intrinsic, combined with bitwise OR.
+enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
+  Basic = 0,
+  V = 0x02,
+  Zvfh = 0x04,
+  RV64 = 0x08,
+  VectorMaxELen64 = 0x10,
+  VectorMaxELenFp32 = 0x20,
+  VectorMaxELenFp64 = 0x40,
+};
+
+/// How an intrinsic carries its policy. With HasPassthruOperand the unmasked
+/// form has a passthru operand, which shifts IntrinsicTypes indices by NF.
+enum PolicyScheme : uint8_t {
+  SchemeNone = 0,
+  HasPassthruOperand = 1,
+  HasPolicyOperand = 2,
+};
+
+/// One instantiation of an RVV intrinsic for a particular type and prototype.
+///
+/// TODO: refactor the RVVIntrinsic class design once all intrinsic
+/// combinations are supported.
+class RVVIntrinsic {
+private:
+  std::string BuiltinName; // Name of the clang builtin.
+  std::string Name;        // C intrinsic name.
+  std::string MangledName; // Base name for the mangled function definition.
+  std::string IRName;
+  bool IsMasked;
+  bool HasVL;
+  PolicyScheme Scheme;
+  bool HasUnMaskedOverloaded;
+  bool HasBuiltinAlias;
+  std::string ManualCodegen;
+  RVVTypePtr OutputType; // Builtin output type
+  RVVTypes InputTypes;   // Builtin input types
+  // Indices into InputTypes selecting the types used to obtain the specific
+  // LLVM intrinsic; -1 denotes the return type.
+  std::vector<int64_t> IntrinsicTypes;
+  RISCVPredefinedMacroT RISCVPredefinedMacros = 0;
+  unsigned NF = 1;
+
+public:
+  RVVIntrinsic(llvm::StringRef Name, llvm::StringRef Suffix,
+               llvm::StringRef MangledName, llvm::StringRef MangledSuffix,
+               llvm::StringRef IRName, bool IsMasked, bool HasMaskedOffOperand,
+               bool HasVL, PolicyScheme Scheme, bool HasUnMaskedOverloaded,
+               bool HasBuiltinAlias, llvm::StringRef ManualCodegen,
+               const RVVTypes &Types,
+               const std::vector<int64_t> &IntrinsicTypes,
+               const std::vector<llvm::StringRef> &RequiredFeatures,
+               unsigned NF);
+  ~RVVIntrinsic() = default;
+
+  // Names.
+  llvm::StringRef getBuiltinName() const { return BuiltinName; }
+  llvm::StringRef getName() const { return Name; }
+  llvm::StringRef getMangledName() const { return MangledName; }
+  llvm::StringRef getIRName() const { return IRName; }
+
+  // Signature.
+  RVVTypePtr getOutputType() const { return OutputType; }
+  const RVVTypes &getInputTypes() const { return InputTypes; }
+  const std::vector<int64_t> &getIntrinsicTypes() const {
+    return IntrinsicTypes;
+  }
+  unsigned getNF() const { return NF; }
+
+  // Flags.
+  bool isMasked() const { return IsMasked; }
+  bool hasVL() const { return HasVL; }
+  bool hasUnMaskedOverloaded() const { return HasUnMaskedOverloaded; }
+  bool hasBuiltinAlias() const { return HasBuiltinAlias; }
+
+  // Policy.
+  PolicyScheme getPolicyScheme() const { return Scheme; }
+  bool hasPolicy() const { return Scheme != SchemeNone; }
+  bool hasPassthruOperand() const { return Scheme == HasPassthruOperand; }
+  bool hasPolicyOperand() const { return Scheme == HasPolicyOperand; }
+
+  // Code generation.
+  bool hasManualCodegen() const { return !ManualCodegen.empty(); }
+  llvm::StringRef getManualCodegen() const { return ManualCodegen; }
+  RISCVPredefinedMacroT getRISCVPredefinedMacros() const {
+    return RISCVPredefinedMacros;
+  }
+
+  // Return the type string for a BUILTIN() macro in Builtins.def.
+  std::string getBuiltinTypeStr() const;
+};
+
+} // end namespace RISCV
+
+} // end namespace clang
+
+#endif // CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
diff --git a/clang/lib/CMakeLists.txt b/clang/lib/CMakeLists.txt
index 50bd0cb55059e..1526d65795f8a 100644
--- a/clang/lib/CMakeLists.txt
+++ b/clang/lib/CMakeLists.txt
@@ -29,3 +29,4 @@ if(CLANG_INCLUDE_TESTS)
add_subdirectory(Testing)
endif()
add_subdirectory(Interpreter)
+add_subdirectory(Support) # clangSupport (RISCVVIntrinsicUtils), also linked by clang-tblgen
diff --git a/clang/lib/Support/CMakeLists.txt b/clang/lib/Support/CMakeLists.txt
new file mode 100644
index 0000000000000..c24324bd7b0d3
--- /dev/null
+++ b/clang/lib/Support/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Save LLVM_COMMON_DEPENDS so it can be restored after clangSupport is defined.
+set(LLVM_COMMON_DEPENDS_OLD ${LLVM_COMMON_DEPENDS})
+
+# Drop clang-tablegen-targets from LLVM_COMMON_DEPENDS so that clangSupport
+# can be used within clang-tblgen and other clang components (presumably to
+# avoid a dependency cycle, since clang-tblgen itself links clangSupport).
+list(REMOVE_ITEM LLVM_COMMON_DEPENDS clang-tablegen-targets)
+
+set(LLVM_LINK_COMPONENTS
+ Support
+ )
+
+add_clang_library(clangSupport
+ RISCVVIntrinsicUtils.cpp
+ )
+
+# Restore the saved dependency list for later targets.
+set(LLVM_COMMON_DEPENDS ${LLVM_COMMON_DEPENDS_OLD})
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
new file mode 100644
index 0000000000000..2e2f92d4804fe
--- /dev/null
+++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -0,0 +1,597 @@
+//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Support/RISCVVIntrinsicUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+using namespace llvm;
+
+namespace clang {
+namespace RISCV {
+
+//===----------------------------------------------------------------------===//
+// Type implementation
+//===----------------------------------------------------------------------===//
+
+/// Construct from log2(LMUL); only LMUL values 1/8 through 8 are accepted.
+LMULType::LMULType(int NewLog2LMUL) : Log2LMUL(NewLog2LMUL) {
+  assert(Log2LMUL >= -3 && Log2LMUL <= 3 && "Bad LMUL number!");
+}
+
+/// Spell LMUL as used in type names: "mf<N>" for LMUL = 1/N, otherwise "m<N>".
+std::string LMULType::str() const {
+  const bool IsFractional = Log2LMUL < 0;
+  const uint64_t Magnitude = 1ULL << (IsFractional ? -Log2LMUL : Log2LMUL);
+  return (IsFractional ? "mf" : "m") + utostr(Magnitude);
+}
+
+/// Return the vector scale (the N in nxvN<elt>) of this LMUL at the given
+/// element width, or None when fewer than one element would fit. Widths other
+/// than 8/16/32/64 leave the exponent at 0, i.e. a scale of 1.
+VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
+  // log2(scale) = log2(LMUL) + log2(64 / SEW) for the supported widths.
+  int Log2Scale = 0;
+  if (ElementBitwidth == 8)
+    Log2Scale = Log2LMUL + 3;
+  else if (ElementBitwidth == 16)
+    Log2Scale = Log2LMUL + 2;
+  else if (ElementBitwidth == 32)
+    Log2Scale = Log2LMUL + 1;
+  else if (ElementBitwidth == 64)
+    Log2Scale = Log2LMUL;
+
+  if (Log2Scale < 0)
+    return llvm::None;
+  return 1 << Log2Scale;
+}
+
+/// Multiply LMUL by 2^Log2Factor; no range check is performed here.
+void LMULType::MulLog2LMUL(int Log2Factor) { this->Log2LMUL += Log2Factor; }
+
+/// Multiply LMUL by RHS, which must be a power of two.
+LMULType &LMULType::operator*=(uint32_t RHS) {
+  assert(isPowerOf2_32(RHS));
+  Log2LMUL += Log2_32(RHS);
+  return *this;
+}
+
+/// Build a type from a basic type code, log2(LMUL) and a prototype modifier.
+/// The string forms are computed only for valid types.
+RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
+    : BT(BT), LMUL(LMULType(Log2LMUL)) {
+  applyBasicType();
+  applyModifier(prototype);
+  Valid = verifyType();
+  if (!Valid)
+    return;
+
+  initBuiltinStr();
+  initTypeStr();
+  // Only vector types have a clang builtin (__rvv_*) spelling.
+  if (isVector())
+    initClangBuiltinStr();
+}
+
+// clang-format off
+// Boolean (mask) types are encoded by the ratio n = SEW/LMUL:
+// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
+// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
+// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
+
+// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
+// -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
+// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
+// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
+// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
+// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
+// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
+// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
+// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
+// clang-format on
+
+/// Return whether the computed type is legal: scalars always are; vectors
+/// need a power-of-two scale no larger than the per-width maximum (see the
+/// table above).
+bool RVVType::verifyType() const {
+  if (ScalarType == Invalid)
+    return false;
+  if (isScalar())
+    return true;
+  if (!Scale.hasValue())
+    return false;
+  // There is no 8-bit floating-point element type.
+  if (isFloat() && ElementBitwidth == 8)
+    return false;
+
+  // Largest legal scale per element width (masks use width 1).
+  unsigned MaxScale;
+  switch (ElementBitwidth) {
+  case 1:
+  case 8:
+    MaxScale = 64;
+    break;
+  case 16:
+    MaxScale = 32;
+    break;
+  case 32:
+    MaxScale = 16;
+    break;
+  case 64:
+    MaxScale = 8;
+    break;
+  default:
+    return false;
+  }
+  const unsigned ScaleValue = Scale.getValue();
+  return isPowerOf2_32(ScaleValue) && ScaleValue <= MaxScale;
+}
+
+// Build the type encoding used by BUILTIN() in Builtins.def: an optional "I"
+// (immediate) prefix and the element code, followed by "C"/"*" suffixes for
+// scalars, or preceded by "q<Scale>" for vectors.
+void RVVType::initBuiltinStr() {
+  assert(isValid() && "RVVType is invalid");
+  switch (ScalarType) {
+  case ScalarTypeKind::Void:
+    BuiltinStr = "v";
+    return;
+  case ScalarTypeKind::Size_t:
+    BuiltinStr = "z";
+    if (IsImmediate)
+      BuiltinStr = "I" + BuiltinStr;
+    if (IsPointer)
+      BuiltinStr += "*";
+    return;
+  case ScalarTypeKind::Ptrdiff_t:
+    BuiltinStr = "Y";
+    return;
+  case ScalarTypeKind::UnsignedLong:
+    BuiltinStr = "ULi";
+    return;
+  case ScalarTypeKind::SignedLong:
+    BuiltinStr = "Li";
+    return;
+  case ScalarTypeKind::Boolean:
+    assert(ElementBitwidth == 1);
+    BuiltinStr += "b";
+    break;
+  case ScalarTypeKind::SignedInteger:
+  case ScalarTypeKind::UnsignedInteger:
+    switch (ElementBitwidth) {
+    case 8:
+      BuiltinStr += "c";
+      break;
+    case 16:
+      BuiltinStr += "s";
+      break;
+    case 32:
+      BuiltinStr += "i";
+      break;
+    case 64:
+      BuiltinStr += "Wi";
+      break;
+    default:
+      llvm_unreachable("Unhandled ElementBitwidth!");
+    }
+    // Integer element types carry an explicit signedness prefix.
+    if (isSignedInteger())
+      BuiltinStr = "S" + BuiltinStr;
+    else
+      BuiltinStr = "U" + BuiltinStr;
+    break;
+  case ScalarTypeKind::Float:
+    switch (ElementBitwidth) {
+    case 16:
+      BuiltinStr += "x";
+      break;
+    case 32:
+      BuiltinStr += "f";
+      break;
+    case 64:
+      BuiltinStr += "d";
+      break;
+    default:
+      llvm_unreachable("Unhandled ElementBitwidth!");
+    }
+    break;
+  default:
+    llvm_unreachable("ScalarType is invalid!");
+  }
+  if (IsImmediate)
+    BuiltinStr = "I" + BuiltinStr;
+  if (isScalar()) {
+    if (IsConstant)
+      BuiltinStr += "C";
+    if (IsPointer)
+      BuiltinStr += "*";
+    return;
+  }
+  BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
+  // Pointers to vector types exist for segment load intrinsics, which take
+  // pointer arguments through which the loaded vector values are stored.
+  if (IsPointer)
+    BuiltinStr += "*";
+}
+
+/// Build the clang builtin vector type name, e.g. "__rvv_int32m1_t" or
+/// "__rvv_bool8_t". Only called for valid vector types.
+void RVVType::initClangBuiltinStr() {
+  assert(isValid() && "RVVType is invalid");
+  assert(isVector() && "Handle Vector type only");
+
+  // Masks are named by their SEW/LMUL ratio instead of element width + LMUL.
+  if (ScalarType == ScalarTypeKind::Boolean) {
+    ClangBuiltinStr = "__rvv_bool" + utostr(64 / Scale.getValue()) + "_t";
+    return;
+  }
+
+  const char *ElementKind;
+  switch (ScalarType) {
+  case ScalarTypeKind::Float:
+    ElementKind = "float";
+    break;
+  case ScalarTypeKind::SignedInteger:
+    ElementKind = "int";
+    break;
+  case ScalarTypeKind::UnsignedInteger:
+    ElementKind = "uint";
+    break;
+  default:
+    llvm_unreachable("ScalarTypeKind is invalid");
+  }
+  ClangBuiltinStr = std::string("__rvv_") + ElementKind +
+                    utostr(ElementBitwidth) + LMUL.str() + "_t";
+}
+
+// Build the C/C++ spelling of this type for riscv_vector.h, e.g. "vint32m1_t",
+// "vbool8_t", "const float *" or "size_t".
+void RVVType::initTypeStr() {
+  assert(isValid() && "RVVType is invalid");
+
+  // NOTE(review): the kinds below that assign Str directly (void, size_t,
+  // ptrdiff_t, unsigned long, long) overwrite this prefix, so 'C' has no
+  // effect on them; initBuiltinStr() likewise emits no "C" for them.
+  if (IsConstant)
+    Str += "const ";
+
+  // Scalars: "<kind><SEW>_t"; vectors: "v<kind><SEW><LMUL>_t".
+  auto getTypeString = [&](StringRef TypeStr) {
+    if (isScalar())
+      return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
+    return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
+        .str();
+  };
+
+  switch (ScalarType) {
+  case ScalarTypeKind::Void:
+    Str = "void";
+    return;
+  case ScalarTypeKind::Size_t:
+    Str = "size_t";
+    if (IsPointer)
+      Str += " *";
+    return;
+  case ScalarTypeKind::Ptrdiff_t:
+    Str = "ptrdiff_t";
+    return;
+  case ScalarTypeKind::UnsignedLong:
+    Str = "unsigned long";
+    return;
+  case ScalarTypeKind::SignedLong:
+    Str = "long";
+    return;
+  case ScalarTypeKind::Boolean:
+    if (isScalar())
+      Str += "bool";
+    else
+      // Vector bool is a special case; the formula is
+      // `vbool<N>_t = MVT::nxv<64/N>i1`, e.g. vbool16_t = MVT::nxv4i1.
+      Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
+    break;
+  case ScalarTypeKind::Float:
+    if (isScalar()) {
+      if (ElementBitwidth == 64)
+        Str += "double";
+      else if (ElementBitwidth == 32)
+        Str += "float";
+      else if (ElementBitwidth == 16)
+        Str += "_Float16";
+      else
+        llvm_unreachable("Unhandled floating type.");
+    } else
+      Str += getTypeString("float");
+    break;
+  case ScalarTypeKind::SignedInteger:
+    Str += getTypeString("int");
+    break;
+  case ScalarTypeKind::UnsignedInteger:
+    Str += getTypeString("uint");
+    break;
+  default:
+    llvm_unreachable("ScalarType is invalid!");
+  }
+  if (IsPointer)
+    Str += " *";
+}
+
+/// Build the short suffix name: "b<N>" for masks, otherwise a kind letter
+/// ('f', 'i' or 'u') plus the element width, followed by LMUL for vectors.
+void RVVType::initShortStr() {
+  if (ScalarType == ScalarTypeKind::Boolean) {
+    assert(isVector());
+    ShortStr = "b" + utostr(64 / Scale.getValue());
+    return;
+  }
+
+  char KindLetter;
+  switch (ScalarType) {
+  case ScalarTypeKind::Float:
+    KindLetter = 'f';
+    break;
+  case ScalarTypeKind::SignedInteger:
+    KindLetter = 'i';
+    break;
+  case ScalarTypeKind::UnsignedInteger:
+    KindLetter = 'u';
+    break;
+  default:
+    llvm_unreachable("Unhandled case!");
+  }
+  ShortStr = KindLetter + utostr(ElementBitwidth);
+  if (isVector())
+    ShortStr += LMUL.str();
+}
+
+/// Derive ScalarType and ElementBitwidth from the basic type code BT.
+void RVVType::applyBasicType() {
+  switch (BT) {
+  // Signed integers: 'c' (8), 's' (16), 'i' (32), 'l' (64) bits.
+  case 'c':
+  case 's':
+  case 'i':
+  case 'l':
+    ScalarType = ScalarTypeKind::SignedInteger;
+    ElementBitwidth = BT == 'c' ? 8 : BT == 's' ? 16 : BT == 'i' ? 32 : 64;
+    break;
+  // Floating point: 'x' (16), 'f' (32), 'd' (64) bits.
+  case 'x':
+  case 'f':
+  case 'd':
+    ScalarType = ScalarTypeKind::Float;
+    ElementBitwidth = BT == 'x' ? 16 : BT == 'f' ? 32 : 64;
+    break;
+  default:
+    llvm_unreachable("Unhandled type code!");
+  }
+  assert(ElementBitwidth != 0 && "Bad element bitwidth!");
+}
+
+// Apply a prototype modifier string. Its last character is the primitive
+// transformer; it may be preceded by at most one "(Name:Value)" complex
+// transformer followed by single-character transformers. Some combinations
+// mark the type Invalid; malformed strings reach llvm_unreachable.
+void RVVType::applyModifier(StringRef Transformer) {
+  if (Transformer.empty())
+    return;
+  // Handle primitive type transformer
+  auto PType = Transformer.back();
+  switch (PType) {
+  case 'e':
+    Scale = 0;
+    break;
+  case 'v':
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'w':
+    ElementBitwidth *= 2;
+    LMUL *= 2;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'q':
+    ElementBitwidth *= 4;
+    LMUL *= 4;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'o':
+    ElementBitwidth *= 8;
+    LMUL *= 8;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'm':
+    ScalarType = ScalarTypeKind::Boolean;
+    Scale = LMUL.getScale(ElementBitwidth);
+    ElementBitwidth = 1;
+    break;
+  case '0':
+    ScalarType = ScalarTypeKind::Void;
+    break;
+  case 'z':
+    ScalarType = ScalarTypeKind::Size_t;
+    break;
+  case 't':
+    ScalarType = ScalarTypeKind::Ptrdiff_t;
+    break;
+  case 'u':
+    ScalarType = ScalarTypeKind::UnsignedLong;
+    break;
+  case 'l':
+    ScalarType = ScalarTypeKind::SignedLong;
+    break;
+  default:
+    llvm_unreachable("Illegal primitive type transformers!");
+  }
+  Transformer = Transformer.drop_back();
+
+  // Extract and compute complex type transformer. It can only appear one time.
+  if (Transformer.startswith("(")) {
+    size_t Idx = Transformer.find(')');
+    assert(Idx != StringRef::npos);
+    StringRef ComplexType = Transformer.slice(1, Idx);
+    Transformer = Transformer.drop_front(Idx + 1);
+    assert(!Transformer.contains('(') &&
+           "Only allow one complex type transformer");
+
+    auto UpdateAndCheckComplexProto = [&]() {
+      Scale = LMUL.getScale(ElementBitwidth);
+      const StringRef VectorPrototypes("vwqom");
+      if (!VectorPrototypes.contains(PType))
+        llvm_unreachable("Complex type transformer only supports vector type!");
+      if (Transformer.find_first_of("PCKWS") != StringRef::npos)
+        llvm_unreachable(
+            "Illegal type transformer for Complex type transformer");
+    };
+    auto ComputeFixedLog2LMUL =
+        [&](StringRef Value,
+            std::function<bool(const int32_t &, const int32_t &)> Compare) {
+          int32_t Log2LMUL;
+          // getAsInteger() returns true on failure and then leaves Log2LMUL
+          // unset; reading it would be undefined behavior.
+          if (Value.getAsInteger(10, Log2LMUL))
+            llvm_unreachable(
+                "Invalid Log2LMUL value in complex type transformer!");
+          if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
+            ScalarType = Invalid;
+            return false;
+          }
+          // Update new LMUL
+          LMUL = LMULType(Log2LMUL);
+          UpdateAndCheckComplexProto();
+          return true;
+        };
+    auto ComplexTT = ComplexType.split(":");
+    if (ComplexTT.first == "Log2EEW") {
+      uint32_t Log2EEW;
+      if (ComplexTT.second.getAsInteger(10, Log2EEW))
+        llvm_unreachable("Invalid Log2EEW value in complex type transformer!");
+      // update new elmul = (eew/sew) * lmul
+      LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
+      // update new eew
+      ElementBitwidth = 1 << Log2EEW;
+      ScalarType = ScalarTypeKind::SignedInteger;
+      UpdateAndCheckComplexProto();
+    } else if (ComplexTT.first == "FixedSEW") {
+      uint32_t NewSEW;
+      if (ComplexTT.second.getAsInteger(10, NewSEW))
+        llvm_unreachable("Invalid FixedSEW value in complex type transformer!");
+      // Set invalid type if src and dst SEW are same.
+      if (ElementBitwidth == NewSEW) {
+        ScalarType = Invalid;
+        return;
+      }
+      // Update new SEW
+      ElementBitwidth = NewSEW;
+      UpdateAndCheckComplexProto();
+    } else if (ComplexTT.first == "LFixedLog2LMUL") {
+      // New LMUL should be larger than old
+      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
+        return;
+    } else if (ComplexTT.first == "SFixedLog2LMUL") {
+      // New LMUL should be smaller than old
+      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
+        return;
+    } else {
+      llvm_unreachable("Illegal complex type transformers!");
+    }
+  }
+
+  // Apply the remaining single-character type transformers.
+  for (char I : Transformer) {
+    switch (I) {
+    case 'P':
+      if (IsConstant)
+        llvm_unreachable("'P' transformer cannot be used after 'C'");
+      if (IsPointer)
+        llvm_unreachable("'P' transformer cannot be used twice");
+      IsPointer = true;
+      break;
+    case 'C':
+      if (IsConstant)
+        llvm_unreachable("'C' transformer cannot be used twice");
+      IsConstant = true;
+      break;
+    case 'K':
+      IsImmediate = true;
+      break;
+    case 'U':
+      ScalarType = ScalarTypeKind::UnsignedInteger;
+      break;
+    case 'I':
+      ScalarType = ScalarTypeKind::SignedInteger;
+      break;
+    case 'F':
+      ScalarType = ScalarTypeKind::Float;
+      break;
+    case 'S':
+      LMUL = LMULType(0);
+      // Changing LMUL requires recomputing Scale as well.
+      Scale = LMUL.getScale(ElementBitwidth);
+      break;
+    default:
+      llvm_unreachable("Illegal non-primitive type transformer!");
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// RVVIntrinsic implementation
+//===----------------------------------------------------------------------===//
+// Build an intrinsic instance. OutInTypes[0] is the output type and the rest
+// are input types; IntrinsicTypes indices are adjusted for a leading
+// merge/passthru operand when one is present.
+RVVIntrinsic::RVVIntrinsic(
+    StringRef NewName, StringRef Suffix, StringRef NewMangledName,
+    StringRef MangledSuffix, StringRef IRName, bool IsMasked,
+    bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
+    bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen,
+    const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
+    const std::vector<StringRef> &RequiredFeatures, unsigned NF)
+    : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme),
+      HasUnMaskedOverloaded(HasUnMaskedOverloaded),
+      HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()),
+      NF(NF) {
+
+  // Init BuiltinName, Name and MangledName
+  BuiltinName = NewName.str();
+  Name = BuiltinName;
+  if (NewMangledName.empty())
+    MangledName = NewName.split("_").first.str();
+  else
+    MangledName = NewMangledName.str();
+  if (!Suffix.empty())
+    Name += "_" + Suffix.str();
+  if (!MangledSuffix.empty())
+    MangledName += "_" + MangledSuffix.str();
+  if (IsMasked) {
+    BuiltinName += "_m";
+    Name += "_m";
+  }
+
+  // Init RISC-V extensions
+  for (const auto &T : OutInTypes) {
+    if (T->isFloatVector(16) || T->isFloat(16))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh;
+    if (T->isFloatVector(32))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32;
+    if (T->isFloatVector(64))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64;
+    if (T->isVector(64))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64;
+  }
+  for (auto Feature : RequiredFeatures) {
+    if (Feature == "RV64")
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64;
+    // Note: Full multiply instruction (mulh, mulhu, mulhsu, smul) for EEW=64
+    // require V.
+    if (Feature == "FullMultiply" &&
+        (RISCVPredefinedMacros & RISCVPredefinedMacro::VectorMaxELen64))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::V;
+  }
+
+  // Init OutputType and InputTypes. Element 0 must exist: it is the output.
+  assert(!OutInTypes.empty() && "RVVIntrinsic needs at least an output type");
+  OutputType = OutInTypes[0];
+  InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
+
+  // IntrinsicTypes holds indices for the unmasked TA version. When a merge
+  // operand is present it occupies the first NF operand positions, so every
+  // non-negative index (i.e. not the -1 return-type marker) shifts by NF.
+  IntrinsicTypes = NewIntrinsicTypes;
+  if ((IsMasked && HasMaskedOffOperand) ||
+      (!IsMasked && hasPassthruOperand())) {
+    for (auto &I : IntrinsicTypes) {
+      if (I >= 0)
+        I += NF;
+    }
+  }
+}
+
+/// Concatenate the Builtins.def encodings of the return type followed by each
+/// input type, in order.
+std::string RVVIntrinsic::getBuiltinTypeStr() const {
+  std::string Encoded = OutputType->getBuiltinStr();
+  for (RVVTypePtr Param : InputTypes)
+    Encoded += Param->getBuiltinStr();
+  return Encoded;
+}
+
+} // end namespace RISCV
+} // end namespace clang
diff --git a/clang/utils/TableGen/CMakeLists.txt b/clang/utils/TableGen/CMakeLists.txt
index 6379cc4e11e83..04aa72cde03e3 100644
--- a/clang/utils/TableGen/CMakeLists.txt
+++ b/clang/utils/TableGen/CMakeLists.txt
@@ -22,4 +22,7 @@ add_tablegen(clang-tblgen CLANG
SveEmitter.cpp
TableGen.cpp
)
+
+target_link_libraries(clang-tblgen PRIVATE clangSupport) # RISCVVEmitter.cpp uses clang/Support/RISCVVIntrinsicUtils.h
+
set_target_properties(clang-tblgen PROPERTIES FOLDER "Clang tablegenning")
diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp
index f26b1189c1e97..bd9e74f2f0cf7 100644
--- a/clang/utils/TableGen/RISCVVEmitter.cpp
+++ b/clang/utils/TableGen/RISCVVEmitter.cpp
@@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//
+#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
@@ -25,206 +26,9 @@
#include <numeric>
using namespace llvm;
-using BasicType = char;
-using VScaleVal = Optional<unsigned>;
+using namespace clang::RISCV;
namespace {
-
-// Exponential LMUL
-struct LMULType {
- int Log2LMUL;
- LMULType(int Log2LMUL);
- // Return the C/C++ string representation of LMUL
- std::string str() const;
- Optional<unsigned> getScale(unsigned ElementBitwidth) const;
- void MulLog2LMUL(int Log2LMUL);
- LMULType &operator*=(uint32_t RHS);
-};
-
-// This class is compact representation of a valid and invalid RVVType.
-class RVVType {
- enum ScalarTypeKind : uint32_t {
- Void,
- Size_t,
- Ptrdiff_t,
- UnsignedLong,
- SignedLong,
- Boolean,
- SignedInteger,
- UnsignedInteger,
- Float,
- Invalid,
- };
- BasicType BT;
- ScalarTypeKind ScalarType = Invalid;
- LMULType LMUL;
- bool IsPointer = false;
- // IsConstant indices are "int", but have the constant expression.
- bool IsImmediate = false;
- // Const qualifier for pointer to const object or object of const type.
- bool IsConstant = false;
- unsigned ElementBitwidth = 0;
- VScaleVal Scale = 0;
- bool Valid;
-
- std::string BuiltinStr;
- std::string ClangBuiltinStr;
- std::string Str;
- std::string ShortStr;
-
-public:
- RVVType() : RVVType(BasicType(), 0, StringRef()) {}
- RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
-
- // Return the string representation of a type, which is an encoded string for
- // passing to the BUILTIN() macro in Builtins.def.
- const std::string &getBuiltinStr() const { return BuiltinStr; }
-
- // Return the clang builtin type for RVV vector type which are used in the
- // riscv_vector.h header file.
- const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
-
- // Return the C/C++ string representation of a type for use in the
- // riscv_vector.h header file.
- const std::string &getTypeStr() const { return Str; }
-
- // Return the short name of a type for C/C++ name suffix.
- const std::string &getShortStr() {
- // Not all types are used in short name, so compute the short name by
- // demanded.
- if (ShortStr.empty())
- initShortStr();
- return ShortStr;
- }
-
- bool isValid() const { return Valid; }
- bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
- bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
- bool isVector(unsigned Width) const {
- return isVector() && ElementBitwidth == Width;
- }
- bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
- bool isSignedInteger() const {
- return ScalarType == ScalarTypeKind::SignedInteger;
- }
- bool isFloatVector(unsigned Width) const {
- return isVector() && isFloat() && ElementBitwidth == Width;
- }
- bool isFloat(unsigned Width) const {
- return isFloat() && ElementBitwidth == Width;
- }
-
-private:
- // Verify RVV vector type and set Valid.
- bool verifyType() const;
-
- // Creates a type based on basic types of TypeRange
- void applyBasicType();
-
- // Applies a prototype modifier to the current type. The result maybe an
- // invalid type.
- void applyModifier(StringRef prototype);
-
- // Compute and record a string for legal type.
- void initBuiltinStr();
- // Compute and record a builtin RVV vector type string.
- void initClangBuiltinStr();
- // Compute and record a type string for used in the header.
- void initTypeStr();
- // Compute and record a short name of a type for C/C++ name suffix.
- void initShortStr();
-};
-
-using RVVTypePtr = RVVType *;
-using RVVTypes = std::vector<RVVTypePtr>;
-using RISCVPredefinedMacroT = uint8_t;
-
-enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
- Basic = 0,
- V = 1 << 1,
- Zvfh = 1 << 2,
- RV64 = 1 << 3,
- VectorMaxELen64 = 1 << 4,
- VectorMaxELenFp32 = 1 << 5,
- VectorMaxELenFp64 = 1 << 6,
-};
-
-enum PolicyScheme : uint8_t {
- SchemeNone,
- HasPassthruOperand,
- HasPolicyOperand,
-};
-
-// TODO refactor RVVIntrinsic class design after support all intrinsic
-// combination. This represents an instantiation of an intrinsic with a
-// particular type and prototype
-class RVVIntrinsic {
-
-private:
- std::string BuiltinName; // Builtin name
- std::string Name; // C intrinsic name.
- std::string MangledName;
- std::string IRName;
- bool IsMasked;
- bool HasVL;
- PolicyScheme Scheme;
- bool HasUnMaskedOverloaded;
- bool HasBuiltinAlias;
- std::string ManualCodegen;
- RVVTypePtr OutputType; // Builtin output type
- RVVTypes InputTypes; // Builtin input types
- // The types we use to obtain the specific LLVM intrinsic. They are index of
- // InputTypes. -1 means the return type.
- std::vector<int64_t> IntrinsicTypes;
- RISCVPredefinedMacroT RISCVPredefinedMacros = 0;
- unsigned NF = 1;
-
-public:
- RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
- StringRef MangledSuffix, StringRef IRName, bool IsMasked,
- bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
- bool HasUnMaskedOverloaded, bool HasBuiltinAlias,
- StringRef ManualCodegen, const RVVTypes &Types,
- const std::vector<int64_t> &IntrinsicTypes,
- const std::vector<StringRef> &RequiredFeatures, unsigned NF);
- ~RVVIntrinsic() = default;
-
- StringRef getBuiltinName() const { return BuiltinName; }
- StringRef getName() const { return Name; }
- StringRef getMangledName() const { return MangledName; }
- bool hasVL() const { return HasVL; }
- bool hasPolicy() const { return Scheme != SchemeNone; }
- bool hasPassthruOperand() const { return Scheme == HasPassthruOperand; }
- bool hasPolicyOperand() const { return Scheme == HasPolicyOperand; }
- bool hasUnMaskedOverloaded() const { return HasUnMaskedOverloaded; }
- bool hasBuiltinAlias() const { return HasBuiltinAlias; }
- bool hasManualCodegen() const { return !ManualCodegen.empty(); }
- bool isMasked() const { return IsMasked; }
- StringRef getIRName() const { return IRName; }
- StringRef getManualCodegen() const { return ManualCodegen; }
- PolicyScheme getPolicyScheme() const { return Scheme; }
- RISCVPredefinedMacroT getRISCVPredefinedMacros() const {
- return RISCVPredefinedMacros;
- }
- unsigned getNF() const { return NF; }
- const std::vector<int64_t> &getIntrinsicTypes() const {
- return IntrinsicTypes;
- }
-
- // Return the type string for a BUILTIN() macro in Builtins.def.
- std::string getBuiltinTypeStr() const;
-
- // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
- // init the RVVIntrinsic ID and IntrinsicTypes.
- void emitCodeGenSwitchBody(raw_ostream &o) const;
-
- // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
- void emitIntrinsicFuncDef(raw_ostream &o) const;
-
- // Emit the mangled function definition.
- void emitMangledFuncDef(raw_ostream &o) const;
-};
-
class RVVEmitter {
private:
RecordKeeper &Records;
@@ -277,602 +81,31 @@ class RVVEmitter {
} // namespace
-//===----------------------------------------------------------------------===//
-// Type implementation
-//===----------------------------------------------------------------------===//
-
-LMULType::LMULType(int NewLog2LMUL) {
- // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
- assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
- Log2LMUL = NewLog2LMUL;
-}
-
-std::string LMULType::str() const {
- if (Log2LMUL < 0)
- return "mf" + utostr(1ULL << (-Log2LMUL));
- return "m" + utostr(1ULL << Log2LMUL);
-}
-
-VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
- int Log2ScaleResult = 0;
- switch (ElementBitwidth) {
- default:
- break;
- case 8:
- Log2ScaleResult = Log2LMUL + 3;
- break;
- case 16:
- Log2ScaleResult = Log2LMUL + 2;
- break;
- case 32:
- Log2ScaleResult = Log2LMUL + 1;
- break;
- case 64:
- Log2ScaleResult = Log2LMUL;
- break;
- }
- // Illegal vscale result would be less than 1
- if (Log2ScaleResult < 0)
- return llvm::None;
- return 1 << Log2ScaleResult;
-}
-
-void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
-
-LMULType &LMULType::operator*=(uint32_t RHS) {
- assert(isPowerOf2_32(RHS));
- this->Log2LMUL = this->Log2LMUL + Log2_32(RHS);
- return *this;
-}
-
-RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
- : BT(BT), LMUL(LMULType(Log2LMUL)) {
- applyBasicType();
- applyModifier(prototype);
- Valid = verifyType();
- if (Valid) {
- initBuiltinStr();
- initTypeStr();
- if (isVector()) {
- initClangBuiltinStr();
- }
- }
-}
-
-// clang-format off
-// boolean type are encoded the ratio of n (SEW/LMUL)
-// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
-// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
-// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
-
-// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
-// -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
-// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
-// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
-// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
-// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
-// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
-// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
-// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
-// clang-format on
-
-bool RVVType::verifyType() const {
- if (ScalarType == Invalid)
- return false;
- if (isScalar())
- return true;
- if (!Scale.hasValue())
- return false;
- if (isFloat() && ElementBitwidth == 8)
- return false;
- unsigned V = Scale.getValue();
- switch (ElementBitwidth) {
- case 1:
- case 8:
- // Check Scale is 1,2,4,8,16,32,64
- return (V <= 64 && isPowerOf2_32(V));
- case 16:
- // Check Scale is 1,2,4,8,16,32
- return (V <= 32 && isPowerOf2_32(V));
- case 32:
- // Check Scale is 1,2,4,8,16
- return (V <= 16 && isPowerOf2_32(V));
- case 64:
- // Check Scale is 1,2,4,8
- return (V <= 8 && isPowerOf2_32(V));
- }
- return false;
-}
-
-void RVVType::initBuiltinStr() {
- assert(isValid() && "RVVType is invalid");
- switch (ScalarType) {
- case ScalarTypeKind::Void:
- BuiltinStr = "v";
- return;
- case ScalarTypeKind::Size_t:
- BuiltinStr = "z";
- if (IsImmediate)
- BuiltinStr = "I" + BuiltinStr;
- if (IsPointer)
- BuiltinStr += "*";
- return;
- case ScalarTypeKind::Ptrdiff_t:
- BuiltinStr = "Y";
- return;
- case ScalarTypeKind::UnsignedLong:
- BuiltinStr = "ULi";
- return;
- case ScalarTypeKind::SignedLong:
- BuiltinStr = "Li";
- return;
- case ScalarTypeKind::Boolean:
- assert(ElementBitwidth == 1);
- BuiltinStr += "b";
- break;
- case ScalarTypeKind::SignedInteger:
- case ScalarTypeKind::UnsignedInteger:
- switch (ElementBitwidth) {
- case 8:
- BuiltinStr += "c";
- break;
- case 16:
- BuiltinStr += "s";
- break;
- case 32:
- BuiltinStr += "i";
- break;
- case 64:
- BuiltinStr += "Wi";
- break;
- default:
- llvm_unreachable("Unhandled ElementBitwidth!");
- }
- if (isSignedInteger())
- BuiltinStr = "S" + BuiltinStr;
- else
- BuiltinStr = "U" + BuiltinStr;
- break;
- case ScalarTypeKind::Float:
- switch (ElementBitwidth) {
- case 16:
- BuiltinStr += "x";
- break;
- case 32:
- BuiltinStr += "f";
- break;
- case 64:
- BuiltinStr += "d";
- break;
- default:
- llvm_unreachable("Unhandled ElementBitwidth!");
- }
- break;
- default:
- llvm_unreachable("ScalarType is invalid!");
- }
- if (IsImmediate)
- BuiltinStr = "I" + BuiltinStr;
- if (isScalar()) {
- if (IsConstant)
- BuiltinStr += "C";
- if (IsPointer)
- BuiltinStr += "*";
- return;
- }
- BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
- // Pointer to vector types. Defined for segment load intrinsics.
- // segment load intrinsics have pointer type arguments to store the loaded
- // vector values.
- if (IsPointer)
- BuiltinStr += "*";
-}
-
-void RVVType::initClangBuiltinStr() {
- assert(isValid() && "RVVType is invalid");
- assert(isVector() && "Handle Vector type only");
-
- ClangBuiltinStr = "__rvv_";
- switch (ScalarType) {
- case ScalarTypeKind::Boolean:
- ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
- return;
- case ScalarTypeKind::Float:
- ClangBuiltinStr += "float";
- break;
- case ScalarTypeKind::SignedInteger:
- ClangBuiltinStr += "int";
- break;
- case ScalarTypeKind::UnsignedInteger:
- ClangBuiltinStr += "uint";
- break;
- default:
- llvm_unreachable("ScalarTypeKind is invalid");
- }
- ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
-}
-
-void RVVType::initTypeStr() {
- assert(isValid() && "RVVType is invalid");
-
- if (IsConstant)
- Str += "const ";
-
- auto getTypeString = [&](StringRef TypeStr) {
- if (isScalar())
- return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
- return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
- .str();
- };
-
- switch (ScalarType) {
- case ScalarTypeKind::Void:
- Str = "void";
- return;
- case ScalarTypeKind::Size_t:
- Str = "size_t";
- if (IsPointer)
- Str += " *";
- return;
- case ScalarTypeKind::Ptrdiff_t:
- Str = "ptrdiff_t";
- return;
- case ScalarTypeKind::UnsignedLong:
- Str = "unsigned long";
- return;
- case ScalarTypeKind::SignedLong:
- Str = "long";
- return;
- case ScalarTypeKind::Boolean:
- if (isScalar())
- Str += "bool";
- else
- // Vector bool is special case, the formulate is
- // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
- Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
- break;
- case ScalarTypeKind::Float:
- if (isScalar()) {
- if (ElementBitwidth == 64)
- Str += "double";
- else if (ElementBitwidth == 32)
- Str += "float";
- else if (ElementBitwidth == 16)
- Str += "_Float16";
- else
- llvm_unreachable("Unhandled floating type.");
- } else
- Str += getTypeString("float");
- break;
- case ScalarTypeKind::SignedInteger:
- Str += getTypeString("int");
- break;
- case ScalarTypeKind::UnsignedInteger:
- Str += getTypeString("uint");
- break;
- default:
- llvm_unreachable("ScalarType is invalid!");
- }
- if (IsPointer)
- Str += " *";
-}
-
-void RVVType::initShortStr() {
- switch (ScalarType) {
- case ScalarTypeKind::Boolean:
- assert(isVector());
- ShortStr = "b" + utostr(64 / Scale.getValue());
- return;
- case ScalarTypeKind::Float:
- ShortStr = "f" + utostr(ElementBitwidth);
- break;
- case ScalarTypeKind::SignedInteger:
- ShortStr = "i" + utostr(ElementBitwidth);
- break;
- case ScalarTypeKind::UnsignedInteger:
- ShortStr = "u" + utostr(ElementBitwidth);
- break;
- default:
- PrintFatalError("Unhandled case!");
- }
- if (isVector())
- ShortStr += LMUL.str();
-}
-
-void RVVType::applyBasicType() {
- switch (BT) {
- case 'c':
- ElementBitwidth = 8;
- ScalarType = ScalarTypeKind::SignedInteger;
- break;
- case 's':
- ElementBitwidth = 16;
- ScalarType = ScalarTypeKind::SignedInteger;
- break;
- case 'i':
- ElementBitwidth = 32;
- ScalarType = ScalarTypeKind::SignedInteger;
- break;
- case 'l':
- ElementBitwidth = 64;
- ScalarType = ScalarTypeKind::SignedInteger;
- break;
- case 'x':
- ElementBitwidth = 16;
- ScalarType = ScalarTypeKind::Float;
- break;
- case 'f':
- ElementBitwidth = 32;
- ScalarType = ScalarTypeKind::Float;
- break;
- case 'd':
- ElementBitwidth = 64;
- ScalarType = ScalarTypeKind::Float;
- break;
- default:
- PrintFatalError("Unhandled type code!");
- }
- assert(ElementBitwidth != 0 && "Bad element bitwidth!");
-}
-
-void RVVType::applyModifier(StringRef Transformer) {
- if (Transformer.empty())
- return;
- // Handle primitive type transformer
- auto PType = Transformer.back();
- switch (PType) {
- case 'e':
- Scale = 0;
- break;
- case 'v':
- Scale = LMUL.getScale(ElementBitwidth);
- break;
- case 'w':
- ElementBitwidth *= 2;
- LMUL *= 2;
- Scale = LMUL.getScale(ElementBitwidth);
- break;
- case 'q':
- ElementBitwidth *= 4;
- LMUL *= 4;
- Scale = LMUL.getScale(ElementBitwidth);
- break;
- case 'o':
- ElementBitwidth *= 8;
- LMUL *= 8;
- Scale = LMUL.getScale(ElementBitwidth);
- break;
- case 'm':
- ScalarType = ScalarTypeKind::Boolean;
- Scale = LMUL.getScale(ElementBitwidth);
- ElementBitwidth = 1;
- break;
- case '0':
- ScalarType = ScalarTypeKind::Void;
- break;
- case 'z':
- ScalarType = ScalarTypeKind::Size_t;
- break;
- case 't':
- ScalarType = ScalarTypeKind::Ptrdiff_t;
- break;
- case 'u':
- ScalarType = ScalarTypeKind::UnsignedLong;
- break;
- case 'l':
- ScalarType = ScalarTypeKind::SignedLong;
- break;
- default:
- PrintFatalError("Illegal primitive type transformers!");
- }
- Transformer = Transformer.drop_back();
-
- // Extract and compute complex type transformer. It can only appear one time.
- if (Transformer.startswith("(")) {
- size_t Idx = Transformer.find(')');
- assert(Idx != StringRef::npos);
- StringRef ComplexType = Transformer.slice(1, Idx);
- Transformer = Transformer.drop_front(Idx + 1);
- assert(!Transformer.contains('(') &&
- "Only allow one complex type transformer");
-
- auto UpdateAndCheckComplexProto = [&]() {
- Scale = LMUL.getScale(ElementBitwidth);
- const StringRef VectorPrototypes("vwqom");
- if (!VectorPrototypes.contains(PType))
- PrintFatalError("Complex type transformer only supports vector type!");
- if (Transformer.find_first_of("PCKWS") != StringRef::npos)
- PrintFatalError(
- "Illegal type transformer for Complex type transformer");
- };
- auto ComputeFixedLog2LMUL =
- [&](StringRef Value,
- std::function<bool(const int32_t &, const int32_t &)> Compare) {
- int32_t Log2LMUL;
- Value.getAsInteger(10, Log2LMUL);
- if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
- ScalarType = Invalid;
- return false;
- }
- // Update new LMUL
- LMUL = LMULType(Log2LMUL);
- UpdateAndCheckComplexProto();
- return true;
- };
- auto ComplexTT = ComplexType.split(":");
- if (ComplexTT.first == "Log2EEW") {
- uint32_t Log2EEW;
- ComplexTT.second.getAsInteger(10, Log2EEW);
- // update new elmul = (eew/sew) * lmul
- LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
- // update new eew
- ElementBitwidth = 1 << Log2EEW;
- ScalarType = ScalarTypeKind::SignedInteger;
- UpdateAndCheckComplexProto();
- } else if (ComplexTT.first == "FixedSEW") {
- uint32_t NewSEW;
- ComplexTT.second.getAsInteger(10, NewSEW);
- // Set invalid type if src and dst SEW are same.
- if (ElementBitwidth == NewSEW) {
- ScalarType = Invalid;
- return;
- }
- // Update new SEW
- ElementBitwidth = NewSEW;
- UpdateAndCheckComplexProto();
- } else if (ComplexTT.first == "LFixedLog2LMUL") {
- // New LMUL should be larger than old
- if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
- return;
- } else if (ComplexTT.first == "SFixedLog2LMUL") {
- // New LMUL should be smaller than old
- if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
- return;
- } else {
- PrintFatalError("Illegal complex type transformers!");
- }
- }
-
- // Compute the remain type transformers
- for (char I : Transformer) {
- switch (I) {
- case 'P':
- if (IsConstant)
- PrintFatalError("'P' transformer cannot be used after 'C'");
- if (IsPointer)
- PrintFatalError("'P' transformer cannot be used twice");
- IsPointer = true;
- break;
- case 'C':
- if (IsConstant)
- PrintFatalError("'C' transformer cannot be used twice");
- IsConstant = true;
- break;
- case 'K':
- IsImmediate = true;
- break;
- case 'U':
- ScalarType = ScalarTypeKind::UnsignedInteger;
- break;
- case 'I':
- ScalarType = ScalarTypeKind::SignedInteger;
- break;
- case 'F':
- ScalarType = ScalarTypeKind::Float;
- break;
- case 'S':
- LMUL = LMULType(0);
- // Update ElementBitwidth need to update Scale too.
- Scale = LMUL.getScale(ElementBitwidth);
- break;
- default:
- PrintFatalError("Illegal non-primitive type transformer!");
- }
- }
-}
-
-//===----------------------------------------------------------------------===//
-// RVVIntrinsic implementation
-//===----------------------------------------------------------------------===//
-RVVIntrinsic::RVVIntrinsic(
- StringRef NewName, StringRef Suffix, StringRef NewMangledName,
- StringRef MangledSuffix, StringRef IRName, bool IsMasked,
- bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
- bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen,
- const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
- const std::vector<StringRef> &RequiredFeatures, unsigned NF)
- : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme),
- HasUnMaskedOverloaded(HasUnMaskedOverloaded),
- HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()),
- NF(NF) {
-
- // Init BuiltinName, Name and MangledName
- BuiltinName = NewName.str();
- Name = BuiltinName;
- if (NewMangledName.empty())
- MangledName = NewName.split("_").first.str();
- else
- MangledName = NewMangledName.str();
- if (!Suffix.empty())
- Name += "_" + Suffix.str();
- if (!MangledSuffix.empty())
- MangledName += "_" + MangledSuffix.str();
- if (IsMasked) {
- BuiltinName += "_m";
- Name += "_m";
- }
-
- // Init RISC-V extensions
- for (const auto &T : OutInTypes) {
- if (T->isFloatVector(16) || T->isFloat(16))
- RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh;
- if (T->isFloatVector(32))
- RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32;
- if (T->isFloatVector(64))
- RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64;
- if (T->isVector(64))
- RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64;
- }
- for (auto Feature : RequiredFeatures) {
- if (Feature == "RV64")
- RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64;
- // Note: Full multiply instruction (mulh, mulhu, mulhsu, smul) for EEW=64
- // require V.
- if (Feature == "FullMultiply" &&
- (RISCVPredefinedMacros & RISCVPredefinedMacro::VectorMaxELen64))
- RISCVPredefinedMacros |= RISCVPredefinedMacro::V;
- }
-
- // Init OutputType and InputTypes
- OutputType = OutInTypes[0];
- InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
-
- // IntrinsicTypes is unmasked TA version index. Need to update it
- // if there is merge operand (It is always in first operand).
- IntrinsicTypes = NewIntrinsicTypes;
- if ((IsMasked && HasMaskedOffOperand) ||
- (!IsMasked && hasPassthruOperand())) {
- for (auto &I : IntrinsicTypes) {
- if (I >= 0)
- I += NF;
- }
- }
-}
-
-std::string RVVIntrinsic::getBuiltinTypeStr() const {
- std::string S;
- S += OutputType->getBuiltinStr();
- for (const auto &T : InputTypes) {
- S += T->getBuiltinStr();
- }
- return S;
-}
-
-void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
- if (!getIRName().empty())
- OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n";
- if (NF >= 2)
- OS << " NF = " + utostr(getNF()) + ";\n";
- if (hasManualCodegen()) {
- OS << ManualCodegen;
+void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) {
+ if (!RVVI->getIRName().empty())
+ OS << " ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n";
+ if (RVVI->getNF() >= 2)
+ OS << " NF = " + utostr(RVVI->getNF()) + ";\n";
+ if (RVVI->hasManualCodegen()) {
+ OS << RVVI->getManualCodegen();
OS << "break;\n";
return;
}
- if (isMasked()) {
- if (hasVL()) {
+ if (RVVI->isMasked()) {
+ if (RVVI->hasVL()) {
OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
- if (hasPolicyOperand())
+ if (RVVI->hasPolicyOperand())
OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType(),"
" TAIL_UNDISTURBED));\n";
} else {
OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
}
} else {
- if (hasPolicyOperand())
+ if (RVVI->hasPolicyOperand())
OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType(), "
"TAIL_UNDISTURBED));\n";
- else if (hasPassthruOperand()) {
+ else if (RVVI->hasPassthruOperand()) {
OS << " Ops.push_back(llvm::UndefValue::get(ResultType));\n";
OS << " std::rotate(Ops.rbegin(), Ops.rbegin() + 1, Ops.rend());\n";
}
@@ -880,7 +113,7 @@ void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
OS << " IntrinsicTypes = {";
ListSeparator LS;
- for (const auto &Idx : IntrinsicTypes) {
+ for (const auto &Idx : RVVI->getIntrinsicTypes()) {
if (Idx == -1)
OS << LS << "ResultType";
else
@@ -889,17 +122,18 @@ void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
// VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is
// always last operand.
- if (hasVL())
+ if (RVVI->hasVL())
OS << ", Ops.back()->getType()";
OS << "};\n";
OS << " break;\n";
}
-void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const {
+void emitIntrinsicFuncDef(const RVVIntrinsic &RVVI, raw_ostream &OS) {
OS << "__attribute__((__clang_builtin_alias__(";
- OS << "__builtin_rvv_" << getBuiltinName() << ")))\n";
- OS << OutputType->getTypeStr() << " " << getName() << "(";
+ OS << "__builtin_rvv_" << RVVI.getBuiltinName() << ")))\n";
+ OS << RVVI.getOutputType()->getTypeStr() << " " << RVVI.getName() << "(";
// Emit function arguments
+ const RVVTypes &InputTypes = RVVI.getInputTypes();
if (!InputTypes.empty()) {
ListSeparator LS;
for (unsigned i = 0; i < InputTypes.size(); ++i)
@@ -908,11 +142,13 @@ void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const {
OS << ");\n";
}
-void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
+void emitMangledFuncDef(const RVVIntrinsic &RVVI, raw_ostream &OS) {
OS << "__attribute__((__clang_builtin_alias__(";
- OS << "__builtin_rvv_" << getBuiltinName() << ")))\n";
- OS << OutputType->getTypeStr() << " " << getMangledName() << "(";
+ OS << "__builtin_rvv_" << RVVI.getBuiltinName() << ")))\n";
+ OS << RVVI.getOutputType()->getTypeStr() << " " << RVVI.getMangledName()
+ << "(";
// Emit function arguments
+ const RVVTypes &InputTypes = RVVI.getInputTypes();
if (!InputTypes.empty()) {
ListSeparator LS;
for (unsigned i = 0; i < InputTypes.size(); ++i)
@@ -1016,7 +252,7 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
// Print intrinsic functions with macro
emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
OS << "__rvv_ai ";
- Inst.emitIntrinsicFuncDef(OS);
+ emitIntrinsicFuncDef(Inst, OS);
});
OS << "#undef __rvv_ai\n\n";
@@ -1031,7 +267,7 @@ void RVVEmitter::createHeader(raw_ostream &OS) {
if (!Inst.isMasked() && !Inst.hasUnMaskedOverloaded())
return;
OS << "__rvv_aio ";
- Inst.emitMangledFuncDef(OS);
+ emitMangledFuncDef(Inst, OS);
});
OS << "#undef __rvv_aio\n";
@@ -1092,7 +328,7 @@ void RVVEmitter::createCodeGen(raw_ostream &OS) {
StringRef CurIRName = Def->getIRName();
if (CurIRName != PrevDef->getIRName() ||
(Def->getManualCodegen() != PrevDef->getManualCodegen())) {
- PrevDef->emitCodeGenSwitchBody(OS);
+ emitCodeGenSwitchBody(PrevDef, OS);
}
PrevDef = Def.get();
@@ -1119,7 +355,7 @@ void RVVEmitter::createCodeGen(raw_ostream &OS) {
else if (P.first->second->getIntrinsicTypes() != Def->getIntrinsicTypes())
PrintFatalError("Builtin with same name has different IntrinsicTypes");
}
- Defs.back()->emitCodeGenSwitchBody(OS);
+ emitCodeGenSwitchBody(Defs.back().get(), OS);
OS << "\n";
}
More information about the cfe-commits mailing list