[llvm] ad57e10 - [RISCV][NFC] Moving RVV intrinsic type related util to llvm/Support

Kito Cheng via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 27 23:35:35 PDT 2022


Author: Kito Cheng
Date: 2022-03-28T14:35:28+08:00
New Revision: ad57e10dbca2fdeff1448afc0aa1cf23d6df8736

URL: https://github.com/llvm/llvm-project/commit/ad57e10dbca2fdeff1448afc0aa1cf23d6df8736
DIFF: https://github.com/llvm/llvm-project/commit/ad57e10dbca2fdeff1448afc0aa1cf23d6df8736.diff

LOG: [RISCV][NFC] Moving RVV intrinsic type related util to llvm/Support

This patch is split from https://reviews.llvm.org/D111617: clang needs these
utilities as well, so they must be moved to llvm/Support.

Reviewed By: khchen

Differential Revision: https://reviews.llvm.org/D121984

Added: 
    llvm/include/llvm/Support/RISCVVIntrinsicUtils.h
    llvm/lib/Support/RISCVVIntrinsicUtils.cpp

Modified: 
    clang/utils/TableGen/RISCVVEmitter.cpp
    llvm/lib/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp
index f26b1189c1e97..accced0292f0e 100644
--- a/clang/utils/TableGen/RISCVVEmitter.cpp
+++ b/clang/utils/TableGen/RISCVVEmitter.cpp
@@ -20,211 +20,15 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/RISCVVIntrinsicUtils.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include <numeric>
 
 using namespace llvm;
-using BasicType = char;
-using VScaleVal = Optional<unsigned>;
+using namespace llvm::RISCV;
 
 namespace {
-
-// Exponential LMUL
-struct LMULType {
-  int Log2LMUL;
-  LMULType(int Log2LMUL);
-  // Return the C/C++ string representation of LMUL
-  std::string str() const;
-  Optional<unsigned> getScale(unsigned ElementBitwidth) const;
-  void MulLog2LMUL(int Log2LMUL);
-  LMULType &operator*=(uint32_t RHS);
-};
-
-// This class is compact representation of a valid and invalid RVVType.
-class RVVType {
-  enum ScalarTypeKind : uint32_t {
-    Void,
-    Size_t,
-    Ptrdiff_t,
-    UnsignedLong,
-    SignedLong,
-    Boolean,
-    SignedInteger,
-    UnsignedInteger,
-    Float,
-    Invalid,
-  };
-  BasicType BT;
-  ScalarTypeKind ScalarType = Invalid;
-  LMULType LMUL;
-  bool IsPointer = false;
-  // IsConstant indices are "int", but have the constant expression.
-  bool IsImmediate = false;
-  // Const qualifier for pointer to const object or object of const type.
-  bool IsConstant = false;
-  unsigned ElementBitwidth = 0;
-  VScaleVal Scale = 0;
-  bool Valid;
-
-  std::string BuiltinStr;
-  std::string ClangBuiltinStr;
-  std::string Str;
-  std::string ShortStr;
-
-public:
-  RVVType() : RVVType(BasicType(), 0, StringRef()) {}
-  RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
-
-  // Return the string representation of a type, which is an encoded string for
-  // passing to the BUILTIN() macro in Builtins.def.
-  const std::string &getBuiltinStr() const { return BuiltinStr; }
-
-  // Return the clang builtin type for RVV vector type which are used in the
-  // riscv_vector.h header file.
-  const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
-
-  // Return the C/C++ string representation of a type for use in the
-  // riscv_vector.h header file.
-  const std::string &getTypeStr() const { return Str; }
-
-  // Return the short name of a type for C/C++ name suffix.
-  const std::string &getShortStr() {
-    // Not all types are used in short name, so compute the short name by
-    // demanded.
-    if (ShortStr.empty())
-      initShortStr();
-    return ShortStr;
-  }
-
-  bool isValid() const { return Valid; }
-  bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
-  bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
-  bool isVector(unsigned Width) const {
-    return isVector() && ElementBitwidth == Width;
-  }
-  bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
-  bool isSignedInteger() const {
-    return ScalarType == ScalarTypeKind::SignedInteger;
-  }
-  bool isFloatVector(unsigned Width) const {
-    return isVector() && isFloat() && ElementBitwidth == Width;
-  }
-  bool isFloat(unsigned Width) const {
-    return isFloat() && ElementBitwidth == Width;
-  }
-
-private:
-  // Verify RVV vector type and set Valid.
-  bool verifyType() const;
-
-  // Creates a type based on basic types of TypeRange
-  void applyBasicType();
-
-  // Applies a prototype modifier to the current type. The result maybe an
-  // invalid type.
-  void applyModifier(StringRef prototype);
-
-  // Compute and record a string for legal type.
-  void initBuiltinStr();
-  // Compute and record a builtin RVV vector type string.
-  void initClangBuiltinStr();
-  // Compute and record a type string for used in the header.
-  void initTypeStr();
-  // Compute and record a short name of a type for C/C++ name suffix.
-  void initShortStr();
-};
-
-using RVVTypePtr = RVVType *;
-using RVVTypes = std::vector<RVVTypePtr>;
-using RISCVPredefinedMacroT = uint8_t;
-
-enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
-  Basic = 0,
-  V = 1 << 1,
-  Zvfh = 1 << 2,
-  RV64 = 1 << 3,
-  VectorMaxELen64 = 1 << 4,
-  VectorMaxELenFp32 = 1 << 5,
-  VectorMaxELenFp64 = 1 << 6,
-};
-
-enum PolicyScheme : uint8_t {
-  SchemeNone,
-  HasPassthruOperand,
-  HasPolicyOperand,
-};
-
-// TODO refactor RVVIntrinsic class design after support all intrinsic
-// combination. This represents an instantiation of an intrinsic with a
-// particular type and prototype
-class RVVIntrinsic {
-
-private:
-  std::string BuiltinName; // Builtin name
-  std::string Name;        // C intrinsic name.
-  std::string MangledName;
-  std::string IRName;
-  bool IsMasked;
-  bool HasVL;
-  PolicyScheme Scheme;
-  bool HasUnMaskedOverloaded;
-  bool HasBuiltinAlias;
-  std::string ManualCodegen;
-  RVVTypePtr OutputType; // Builtin output type
-  RVVTypes InputTypes;   // Builtin input types
-  // The types we use to obtain the specific LLVM intrinsic. They are index of
-  // InputTypes. -1 means the return type.
-  std::vector<int64_t> IntrinsicTypes;
-  RISCVPredefinedMacroT RISCVPredefinedMacros = 0;
-  unsigned NF = 1;
-
-public:
-  RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
-               StringRef MangledSuffix, StringRef IRName, bool IsMasked,
-               bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
-               bool HasUnMaskedOverloaded, bool HasBuiltinAlias,
-               StringRef ManualCodegen, const RVVTypes &Types,
-               const std::vector<int64_t> &IntrinsicTypes,
-               const std::vector<StringRef> &RequiredFeatures, unsigned NF);
-  ~RVVIntrinsic() = default;
-
-  StringRef getBuiltinName() const { return BuiltinName; }
-  StringRef getName() const { return Name; }
-  StringRef getMangledName() const { return MangledName; }
-  bool hasVL() const { return HasVL; }
-  bool hasPolicy() const { return Scheme != SchemeNone; }
-  bool hasPassthruOperand() const { return Scheme == HasPassthruOperand; }
-  bool hasPolicyOperand() const { return Scheme == HasPolicyOperand; }
-  bool hasUnMaskedOverloaded() const { return HasUnMaskedOverloaded; }
-  bool hasBuiltinAlias() const { return HasBuiltinAlias; }
-  bool hasManualCodegen() const { return !ManualCodegen.empty(); }
-  bool isMasked() const { return IsMasked; }
-  StringRef getIRName() const { return IRName; }
-  StringRef getManualCodegen() const { return ManualCodegen; }
-  PolicyScheme getPolicyScheme() const { return Scheme; }
-  RISCVPredefinedMacroT getRISCVPredefinedMacros() const {
-    return RISCVPredefinedMacros;
-  }
-  unsigned getNF() const { return NF; }
-  const std::vector<int64_t> &getIntrinsicTypes() const {
-    return IntrinsicTypes;
-  }
-
-  // Return the type string for a BUILTIN() macro in Builtins.def.
-  std::string getBuiltinTypeStr() const;
-
-  // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
-  // init the RVVIntrinsic ID and IntrinsicTypes.
-  void emitCodeGenSwitchBody(raw_ostream &o) const;
-
-  // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
-  void emitIntrinsicFuncDef(raw_ostream &o) const;
-
-  // Emit the mangled function definition.
-  void emitMangledFuncDef(raw_ostream &o) const;
-};
-
 class RVVEmitter {
 private:
   RecordKeeper &Records;

diff  --git a/llvm/include/llvm/Support/RISCVVIntrinsicUtils.h b/llvm/include/llvm/Support/RISCVVIntrinsicUtils.h
new file mode 100644
index 0000000000000..d47c7ac181b5d
--- /dev/null
+++ b/llvm/include/llvm/Support/RISCVVIntrinsicUtils.h
@@ -0,0 +1,225 @@
+//===- RISCVVIntrinsicUtils.h - RISC-V Vector Intrinsic Utils ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_RISCVVINTRINSICUTILS_H
+#define LLVM_SUPPORT_RISCVVINTRINSICUTILS_H
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace llvm { class raw_ostream; } // Used by RVVIntrinsic::emit*().
+
+namespace llvm {
+namespace RISCV {
+
+using BasicType = char; // Element type code: 'c', 's', 'i', 'l', 'x', 'f', 'd'.
+using VScaleVal = Optional<unsigned>; // 0 = scalar, None = illegal vector.
+
+// Exponential LMUL: LMUL stored as its base-2 logarithm.
+struct LMULType {
+  int Log2LMUL; // The constructor asserts -3 <= Log2LMUL <= 3.
+  LMULType(int Log2LMUL);
+  // Return the C/C++ string representation of LMUL, e.g. "mf2" or "m4".
+  std::string str() const;
+  Optional<unsigned> getScale(unsigned ElementBitwidth) const; // None if < 1.
+  void MulLog2LMUL(int Log2LMUL);     // Add Log2LMUL to the exponent.
+  LMULType &operator*=(uint32_t RHS); // RHS must be a power of two.
+};
+
+// Compact representation of an RVV type, which may be valid or invalid.
+class RVVType {
+  enum ScalarTypeKind : uint32_t {
+    Void,
+    Size_t,
+    Ptrdiff_t,
+    UnsignedLong,
+    SignedLong,
+    Boolean,
+    SignedInteger,
+    UnsignedInteger,
+    Float,
+    Invalid,
+  };
+  BasicType BT;
+  ScalarTypeKind ScalarType = Invalid;
+  LMULType LMUL;
+  bool IsPointer = false;
+  // Immediate operands must be constant expressions ("I" in BuiltinStr).
+  bool IsImmediate = false;
+  // Const qualifier for pointer to const object or object of const type.
+  bool IsConstant = false;
+  unsigned ElementBitwidth = 0;
+  VScaleVal Scale = 0;
+  bool Valid = false; // Set by the constructor from verifyType().
+
+  std::string BuiltinStr;
+  std::string ClangBuiltinStr;
+  std::string Str;
+  std::string ShortStr;
+
+public:
+  RVVType() : RVVType(BasicType(), 0, StringRef()) {}
+  RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
+
+  // Return the encoded type string passed to the BUILTIN() macro in
+  // Builtins.def.
+  const std::string &getBuiltinStr() const { return BuiltinStr; }
+
+  // Return the clang builtin type name (e.g. "__rvv_int32m1_t") used in the
+  // riscv_vector.h header; it is only computed for vector types.
+  const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
+
+  // Return the C/C++ string representation of a type for use in the
+  // riscv_vector.h header file.
+  const std::string &getTypeStr() const { return Str; }
+
+  // Return the short name of a type for C/C++ name suffix.
+  const std::string &getShortStr() {
+    // Not all types need a short name, so it is computed lazily on first
+    // use.
+    if (ShortStr.empty())
+      initShortStr();
+    return ShortStr;
+  }
+
+  bool isValid() const { return Valid; }
+  bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
+  bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
+  bool isVector(unsigned Width) const {
+    return isVector() && ElementBitwidth == Width;
+  }
+  bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
+  bool isSignedInteger() const {
+    return ScalarType == ScalarTypeKind::SignedInteger;
+  }
+  bool isFloatVector(unsigned Width) const {
+    return isVector() && isFloat() && ElementBitwidth == Width;
+  }
+  bool isFloat(unsigned Width) const {
+    return isFloat() && ElementBitwidth == Width;
+  }
+
+private:
+  // Return whether this is a legal RVV type (the constructor stores it).
+  bool verifyType() const;
+
+  // Set ElementBitwidth and ScalarType from the basic type code BT.
+  void applyBasicType();
+
+  // Apply a prototype modifier to the current type. The result may be an
+  // invalid type.
+  void applyModifier(StringRef prototype);
+
+  // Compute and record the Builtins.def type string (BuiltinStr).
+  void initBuiltinStr();
+  // Compute and record the "__rvv_*" builtin vector type name.
+  void initClangBuiltinStr();
+  // Compute and record the C type string used in the header (Str).
+  void initTypeStr();
+  // Compute and record a short name of a type for C/C++ name suffix.
+  void initShortStr();
+};
+
+using RVVTypePtr = RVVType *; // RVVIntrinsic does not own these.
+using RVVTypes = std::vector<RVVTypePtr>;
+using RISCVPredefinedMacroT = uint8_t; // Bitmask of RISCVPredefinedMacro.
+
+enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
+  Basic = 0,
+  V = 1 << 1,
+  Zvfh = 1 << 2,              // Set if an operand/result is f16 (any shape).
+  RV64 = 1 << 3,              // Set if "RV64" is in RequiredFeatures.
+  VectorMaxELen64 = 1 << 4,   // Set if a vector has 64-bit elements.
+  VectorMaxELenFp32 = 1 << 5, // Set if a vector has 32-bit float elements.
+  VectorMaxELenFp64 = 1 << 6, // Set if a vector has 64-bit float elements.
+};
+
+enum PolicyScheme : uint8_t {
+  SchemeNone,         // RVVIntrinsic::hasPolicy() returns false.
+  HasPassthruOperand, // RVVIntrinsic::hasPassthruOperand() returns true.
+  HasPolicyOperand,   // RVVIntrinsic::hasPolicyOperand() returns true.
+};
+
+// TODO: refactor the RVVIntrinsic class design once all intrinsic
+// combinations are supported. This represents an instantiation of an
+// intrinsic with a particular type and prototype.
+class RVVIntrinsic {
+
+private:
+  std::string BuiltinName; // Builtin name; "_m" is appended when masked.
+  std::string Name;        // C intrinsic name.
+  std::string MangledName; // Mangled C name; see emitMangledFuncDef().
+  std::string IRName;
+  bool IsMasked;
+  bool HasVL;
+  PolicyScheme Scheme;
+  bool HasUnMaskedOverloaded;
+  bool HasBuiltinAlias;
+  std::string ManualCodegen; // Empty when there is no manual codegen.
+  RVVTypePtr OutputType; // Builtin output type
+  RVVTypes InputTypes;   // Builtin input types
+  // The types used to select the specific LLVM intrinsic, as indices into
+  // InputTypes; -1 means the return type.
+  std::vector<int64_t> IntrinsicTypes;
+  RISCVPredefinedMacroT RISCVPredefinedMacros = 0; // Bitmask.
+  unsigned NF = 1;
+
+public:
+  RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
+               StringRef MangledSuffix, StringRef IRName, bool IsMasked,
+               bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
+               bool HasUnMaskedOverloaded, bool HasBuiltinAlias,
+               StringRef ManualCodegen, const RVVTypes &Types,
+               const std::vector<int64_t> &IntrinsicTypes,
+               const std::vector<StringRef> &RequiredFeatures, unsigned NF);
+  ~RVVIntrinsic() = default; // Does not free OutputType/InputTypes.
+
+  StringRef getBuiltinName() const { return BuiltinName; }
+  StringRef getName() const { return Name; }
+  StringRef getMangledName() const { return MangledName; }
+  bool hasVL() const { return HasVL; }
+  bool hasPolicy() const { return Scheme != SchemeNone; }
+  bool hasPassthruOperand() const { return Scheme == HasPassthruOperand; }
+  bool hasPolicyOperand() const { return Scheme == HasPolicyOperand; }
+  bool hasUnMaskedOverloaded() const { return HasUnMaskedOverloaded; }
+  bool hasBuiltinAlias() const { return HasBuiltinAlias; }
+  bool hasManualCodegen() const { return !ManualCodegen.empty(); }
+  bool isMasked() const { return IsMasked; }
+  StringRef getIRName() const { return IRName; }
+  StringRef getManualCodegen() const { return ManualCodegen; }
+  PolicyScheme getPolicyScheme() const { return Scheme; }
+  RISCVPredefinedMacroT getRISCVPredefinedMacros() const {
+    return RISCVPredefinedMacros;
+  }
+  unsigned getNF() const { return NF; }
+  const std::vector<int64_t> &getIntrinsicTypes() const {
+    return IntrinsicTypes;
+  }
+
+  // Return the type string for a BUILTIN() macro in Builtins.def.
+  std::string getBuiltinTypeStr() const;
+
+  // Emit the switch-case body for EmitRISCVBuiltinExpr; it should
+  // initialize the RVVIntrinsic ID and IntrinsicTypes.
+  void emitCodeGenSwitchBody(raw_ostream &o) const;
+
+  // Emit the macros that map a C/C++ intrinsic function to its builtin.
+  void emitIntrinsicFuncDef(raw_ostream &o) const;
+
+  // Emit the mangled function definition.
+  void emitMangledFuncDef(raw_ostream &o) const;
+};
+
+} // end namespace RISCV
+
+} // end namespace llvm
+
+#endif // LLVM_SUPPORT_RISCVVINTRINSICUTILS_H

diff  --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index 7cbff3dddbcdf..fae4e0ef7b58a 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -191,6 +191,7 @@ add_llvm_component_library(LLVMSupport
   RISCVAttributes.cpp
   RISCVAttributeParser.cpp
   RISCVISAInfo.cpp
+  RISCVVIntrinsicUtils.cpp
   ScaledNumber.cpp
   ScopedPrinter.cpp
   SHA1.cpp

diff  --git a/llvm/lib/Support/RISCVVIntrinsicUtils.cpp b/llvm/lib/Support/RISCVVIntrinsicUtils.cpp
new file mode 100644
index 0000000000000..53a7baae11dbc
--- /dev/null
+++ b/llvm/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -0,0 +1,668 @@
+//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/RISCVVIntrinsicUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+namespace llvm {
+namespace RISCV {
+
+//===----------------------------------------------------------------------===//
+// Type implementation
+//===----------------------------------------------------------------------===//
+
+LMULType::LMULType(int NewLog2LMUL) {
+  // Log2LMUL must be in [-3, 3], i.e. LMUL is one of 1/8, 1/4, ..., 8.
+  assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
+  Log2LMUL = NewLog2LMUL;
+}
+
+std::string LMULType::str() const {
+  if (Log2LMUL < 0)
+    return "mf" + utostr(1ULL << (-Log2LMUL)); // Fractional, e.g. "mf2".
+  return "m" + utostr(1ULL << Log2LMUL);      // Integral, e.g. "m4".
+}
+
+VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
+  int Log2ScaleResult = 0; // log2(LMUL * 64 / ElementBitwidth)
+  switch (ElementBitwidth) {
+  default: // Other widths keep Log2ScaleResult == 0, i.e. a scale of 1.
+    break;
+  case 8:
+    Log2ScaleResult = Log2LMUL + 3;
+    break;
+  case 16:
+    Log2ScaleResult = Log2LMUL + 2;
+    break;
+  case 32:
+    Log2ScaleResult = Log2LMUL + 1;
+    break;
+  case 64:
+    Log2ScaleResult = Log2LMUL;
+    break;
+  }
+  // A scale below 1 is not representable, so report it as illegal.
+  if (Log2ScaleResult < 0)
+    return llvm::None;
+  return 1 << Log2ScaleResult;
+}
+
+void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
+
+LMULType &LMULType::operator*=(uint32_t RHS) {
+  assert(isPowerOf2_32(RHS)); // Keeps the product an exact power of two.
+  this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); // LMUL *= RHS, in log2.
+  return *this;
+}
+
+RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
+    : BT(BT), LMUL(LMULType(Log2LMUL)) {
+  applyBasicType();         // Element width and kind from BT.
+  applyModifier(prototype); // May mark the type as Invalid.
+  Valid = verifyType();
+  if (Valid) { // Type strings are only built for legal types.
+    initBuiltinStr();
+    initTypeStr();
+    if (isVector()) { // Only vector types have a "__rvv_*" name.
+      initClangBuiltinStr();
+    }
+  }
+}
+
+// clang-format off
+// Boolean types vbool<N>_t encode the ratio N = SEW/LMUL (IR: nxv<64/N>i1).
+// 64/N     | 1         | 2         | 4         | 8        | 16        | 32        | 64
+// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
+// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
+
+// type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
+// --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
+// i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
+// i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
+// i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
+// i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
+// double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
+// float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
+// half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
+// clang-format on
+
+bool RVVType::verifyType() const { // True iff this is a legal RVV type.
+  if (ScalarType == Invalid) // Rejected while applying a modifier.
+    return false;
+  if (isScalar()) // Scalars have no LMUL/SEW constraint.
+    return true;
+  if (!Scale.hasValue()) // LMUL too small for this element width.
+    return false;
+  if (isFloat() && ElementBitwidth == 8) // 8-bit float vectors are rejected.
+    return false;
+  unsigned V = Scale.getValue();
+  switch (ElementBitwidth) {
+  case 1: // Mask (boolean) vectors share the 8-bit limits.
+  case 8:
+    // Check Scale is 1,2,4,8,16,32,64
+    return (V <= 64 && isPowerOf2_32(V));
+  case 16:
+    // Check Scale is 1,2,4,8,16,32
+    return (V <= 32 && isPowerOf2_32(V));
+  case 32:
+    // Check Scale is 1,2,4,8,16
+    return (V <= 16 && isPowerOf2_32(V));
+  case 64:
+    // Check Scale is 1,2,4,8
+    return (V <= 8 && isPowerOf2_32(V));
+  }
+  return false; // Unsupported element width.
+}
+
+void RVVType::initBuiltinStr() { // Encoding for the BUILTIN() macro.
+  assert(isValid() && "RVVType is invalid");
+  switch (ScalarType) {
+  case ScalarTypeKind::Void:
+    BuiltinStr = "v";
+    return;
+  case ScalarTypeKind::Size_t:
+    BuiltinStr = "z";
+    if (IsImmediate)
+      BuiltinStr = "I" + BuiltinStr;
+    if (IsPointer)
+      BuiltinStr += "*";
+    return;
+  case ScalarTypeKind::Ptrdiff_t:
+    BuiltinStr = "Y";
+    return;
+  case ScalarTypeKind::UnsignedLong:
+    BuiltinStr = "ULi";
+    return;
+  case ScalarTypeKind::SignedLong:
+    BuiltinStr = "Li";
+    return;
+  case ScalarTypeKind::Boolean:
+    assert(ElementBitwidth == 1);
+    BuiltinStr += "b";
+    break;
+  case ScalarTypeKind::SignedInteger:
+  case ScalarTypeKind::UnsignedInteger:
+    switch (ElementBitwidth) {
+    case 8:
+      BuiltinStr += "c";
+      break;
+    case 16:
+      BuiltinStr += "s";
+      break;
+    case 32:
+      BuiltinStr += "i";
+      break;
+    case 64:
+      BuiltinStr += "Wi";
+      break;
+    default:
+      llvm_unreachable("Unhandled ElementBitwidth!");
+    }
+    if (isSignedInteger()) // Integers always carry an explicit sign prefix.
+      BuiltinStr = "S" + BuiltinStr;
+    else
+      BuiltinStr = "U" + BuiltinStr;
+    break;
+  case ScalarTypeKind::Float:
+    switch (ElementBitwidth) {
+    case 16:
+      BuiltinStr += "x";
+      break;
+    case 32:
+      BuiltinStr += "f";
+      break;
+    case 64:
+      BuiltinStr += "d";
+      break;
+    default:
+      llvm_unreachable("Unhandled ElementBitwidth!");
+    }
+    break;
+  default:
+    llvm_unreachable("ScalarType is invalid!");
+  }
+  if (IsImmediate) // Immediates must be constant expressions.
+    BuiltinStr = "I" + BuiltinStr;
+  if (isScalar()) {
+    if (IsConstant)
+      BuiltinStr += "C";
+    if (IsPointer)
+      BuiltinStr += "*";
+    return;
+  }
+  BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
+  // Pointers to vector types are defined for segment load intrinsics, which
+  // take pointer-typed arguments through which the loaded vector values are
+  // stored.
+  if (IsPointer)
+    BuiltinStr += "*";
+}
+
+void RVVType::initClangBuiltinStr() { // e.g. "__rvv_int32m1_t".
+  assert(isValid() && "RVVType is invalid");
+  assert(isVector() && "Handle Vector type only");
+
+  ClangBuiltinStr = "__rvv_";
+  switch (ScalarType) {
+  case ScalarTypeKind::Boolean:
+    ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
+    return; // Mask types carry no element-width or LMUL suffix.
+  case ScalarTypeKind::Float:
+    ClangBuiltinStr += "float";
+    break;
+  case ScalarTypeKind::SignedInteger:
+    ClangBuiltinStr += "int";
+    break;
+  case ScalarTypeKind::UnsignedInteger:
+    ClangBuiltinStr += "uint";
+    break;
+  default:
+    llvm_unreachable("ScalarTypeKind is invalid");
+  }
+  ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
+}
+
+void RVVType::initTypeStr() { // C spelling used in riscv_vector.h.
+  assert(isValid() && "RVVType is invalid");
+
+  if (IsConstant)
+    Str += "const ";
+
+  auto getTypeString = [&](StringRef TypeStr) {
+    if (isScalar())
+      return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
+    return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
+        .str();
+  };
+
+  switch (ScalarType) {
+  case ScalarTypeKind::Void:
+    Str = "void"; // Plain '=' drops any "const " prefix added above.
+    return;
+  case ScalarTypeKind::Size_t:
+    Str = "size_t";
+    if (IsPointer)
+      Str += " *";
+    return;
+  case ScalarTypeKind::Ptrdiff_t:
+    Str = "ptrdiff_t";
+    return;
+  case ScalarTypeKind::UnsignedLong:
+    Str = "unsigned long";
+    return;
+  case ScalarTypeKind::SignedLong:
+    Str = "long";
+    return;
+  case ScalarTypeKind::Boolean:
+    if (isScalar())
+      Str += "bool";
+    else
+      // Vector bools are a special case: vbool<N>_t = MVT::nxv<64/N>i1,
+      // e.g. vbool16_t = MVT::nxv4i1.
+      Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
+    break;
+  case ScalarTypeKind::Float:
+    if (isScalar()) {
+      if (ElementBitwidth == 64)
+        Str += "double";
+      else if (ElementBitwidth == 32)
+        Str += "float";
+      else if (ElementBitwidth == 16)
+        Str += "_Float16";
+      else
+        llvm_unreachable("Unhandled floating type.");
+    } else
+      Str += getTypeString("float");
+    break;
+  case ScalarTypeKind::SignedInteger:
+    Str += getTypeString("int");
+    break;
+  case ScalarTypeKind::UnsignedInteger:
+    Str += getTypeString("uint");
+    break;
+  default:
+    llvm_unreachable("ScalarType is invalid!");
+  }
+  if (IsPointer)
+    Str += " *";
+}
+
+void RVVType::initShortStr() { // e.g. "i32m1", "f16", "b8".
+  switch (ScalarType) {
+  case ScalarTypeKind::Boolean:
+    assert(isVector()); // Only mask vectors have a boolean short name.
+    ShortStr = "b" + utostr(64 / Scale.getValue());
+    return;
+  case ScalarTypeKind::Float:
+    ShortStr = "f" + utostr(ElementBitwidth);
+    break;
+  case ScalarTypeKind::SignedInteger:
+    ShortStr = "i" + utostr(ElementBitwidth);
+    break;
+  case ScalarTypeKind::UnsignedInteger:
+    ShortStr = "u" + utostr(ElementBitwidth);
+    break;
+  default:
+    llvm_unreachable("Unhandled case!");
+  }
+  if (isVector()) // Scalars get no LMUL suffix.
+    ShortStr += LMUL.str();
+}
+
+void RVVType::applyBasicType() { // Decode BT into width and scalar kind.
+  switch (BT) {
+  case 'c': // int8_t
+    ElementBitwidth = 8;
+    ScalarType = ScalarTypeKind::SignedInteger;
+    break;
+  case 's': // int16_t
+    ElementBitwidth = 16;
+    ScalarType = ScalarTypeKind::SignedInteger;
+    break;
+  case 'i': // int32_t
+    ElementBitwidth = 32;
+    ScalarType = ScalarTypeKind::SignedInteger;
+    break;
+  case 'l': // int64_t
+    ElementBitwidth = 64;
+    ScalarType = ScalarTypeKind::SignedInteger;
+    break;
+  case 'x': // _Float16
+    ElementBitwidth = 16;
+    ScalarType = ScalarTypeKind::Float;
+    break;
+  case 'f': // float
+    ElementBitwidth = 32;
+    ScalarType = ScalarTypeKind::Float;
+    break;
+  case 'd': // double
+    ElementBitwidth = 64;
+    ScalarType = ScalarTypeKind::Float;
+    break;
+  default:
+    llvm_unreachable("Unhandled type code!");
+  }
+  assert(ElementBitwidth != 0 && "Bad element bitwidth!");
+}
+
+void RVVType::applyModifier(StringRef Transformer) {
+  if (Transformer.empty())
+    return;
+  // The last character is the primitive type transformer.
+  auto PType = Transformer.back();
+  switch (PType) {
+  case 'e':
+    Scale = 0;
+    break;
+  case 'v':
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'w':
+    ElementBitwidth *= 2;
+    LMUL *= 2;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'q':
+    ElementBitwidth *= 4;
+    LMUL *= 4;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'o':
+    ElementBitwidth *= 8;
+    LMUL *= 8;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case 'm':
+    ScalarType = ScalarTypeKind::Boolean;
+    Scale = LMUL.getScale(ElementBitwidth);
+    ElementBitwidth = 1;
+    break;
+  case '0':
+    ScalarType = ScalarTypeKind::Void;
+    break;
+  case 'z':
+    ScalarType = ScalarTypeKind::Size_t;
+    break;
+  case 't':
+    ScalarType = ScalarTypeKind::Ptrdiff_t;
+    break;
+  case 'u':
+    ScalarType = ScalarTypeKind::UnsignedLong;
+    break;
+  case 'l':
+    ScalarType = ScalarTypeKind::SignedLong;
+    break;
+  default:
+    llvm_unreachable("Illegal primitive type transformers!");
+  }
+  Transformer = Transformer.drop_back();
+
+  // Extract and apply the complex type transformer; at most one is allowed.
+  if (Transformer.startswith("(")) {
+    size_t Idx = Transformer.find(')');
+    assert(Idx != StringRef::npos);
+    StringRef ComplexType = Transformer.slice(1, Idx);
+    Transformer = Transformer.drop_front(Idx + 1);
+    assert(!Transformer.contains('(') &&
+           "Only allow one complex type transformer");
+
+    auto UpdateAndCheckComplexProto = [&]() {
+      Scale = LMUL.getScale(ElementBitwidth);
+      const StringRef VectorPrototypes("vwqom");
+      if (!VectorPrototypes.contains(PType))
+        llvm_unreachable("Complex type transformer only supports vector type!");
+      if (Transformer.find_first_of("PCKWS") != StringRef::npos)
+        llvm_unreachable(
+            "Illegal type transformer for Complex type transformer");
+    };
+    auto ComputeFixedLog2LMUL =
+        [&](StringRef Value,
+            std::function<bool(const int32_t &, const int32_t &)> Compare) {
+          int32_t Log2LMUL;
+          if (Value.getAsInteger(10, Log2LMUL)) // Never read it unparsed.
+            report_fatal_error("Invalid Log2LMUL value in complex type!");
+          if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
+            ScalarType = Invalid;
+            return false;
+          }
+          LMUL = LMULType(Log2LMUL); // Update LMUL.
+          UpdateAndCheckComplexProto();
+          return true;
+        };
+    auto ComplexTT = ComplexType.split(":");
+    if (ComplexTT.first == "Log2EEW") {
+      uint32_t Log2EEW;
+      if (ComplexTT.second.getAsInteger(10, Log2EEW))
+        report_fatal_error("Invalid Log2EEW value in complex type!");
+      // EMUL = (EEW / SEW) * LMUL, computed in log2 space.
+      LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
+      ElementBitwidth = 1 << Log2EEW; // New EEW.
+      ScalarType = ScalarTypeKind::SignedInteger;
+      UpdateAndCheckComplexProto();
+    } else if (ComplexTT.first == "FixedSEW") {
+      uint32_t NewSEW;
+      if (ComplexTT.second.getAsInteger(10, NewSEW))
+        report_fatal_error("Invalid FixedSEW value in complex type!");
+      // Set invalid type if source and destination SEW are the same.
+      if (ElementBitwidth == NewSEW) {
+        ScalarType = Invalid;
+        return;
+      }
+      ElementBitwidth = NewSEW; // Update SEW.
+      UpdateAndCheckComplexProto();
+    } else if (ComplexTT.first == "LFixedLog2LMUL") {
+      // New LMUL should be larger than old
+      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
+        return;
+    } else if (ComplexTT.first == "SFixedLog2LMUL") {
+      // New LMUL should be smaller than old
+      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
+        return;
+    } else {
+      llvm_unreachable("Illegal complex type transformers!");
+    }
+  }
+
+  // Apply the remaining (non-primitive) type transformers.
+  for (char I : Transformer) {
+    switch (I) {
+    case 'P':
+      if (IsConstant)
+        llvm_unreachable("'P' transformer cannot be used after 'C'");
+      if (IsPointer)
+        llvm_unreachable("'P' transformer cannot be used twice");
+      IsPointer = true;
+      break;
+    case 'C':
+      if (IsConstant)
+        llvm_unreachable("'C' transformer cannot be used twice");
+      IsConstant = true;
+      break;
+    case 'K':
+      IsImmediate = true;
+      break;
+    case 'U':
+      ScalarType = ScalarTypeKind::UnsignedInteger;
+      break;
+    case 'I':
+      ScalarType = ScalarTypeKind::SignedInteger;
+      break;
+    case 'F':
+      ScalarType = ScalarTypeKind::Float;
+      break;
+    case 'S':
+      LMUL = LMULType(0);
+      // Update ElementBitwidth need to update Scale too.
+      Scale = LMUL.getScale(ElementBitwidth);
+      break;
+    default:
+      llvm_unreachable("Illegal non-primitive type transformer!");
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// RVVIntrinsic implementation
+//===----------------------------------------------------------------------===//
+RVVIntrinsic::RVVIntrinsic(
+    StringRef NewName, StringRef Suffix, StringRef NewMangledName,
+    StringRef MangledSuffix, StringRef IRName, bool IsMasked,
+    bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
+    bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen,
+    const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
+    const std::vector<StringRef> &RequiredFeatures, unsigned NF)
+    : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme),
+      HasUnMaskedOverloaded(HasUnMaskedOverloaded),
+      HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()),
+      NF(NF) {
+  // Names:
+  //  - BuiltinName: NewName (plus "_m" when masked).
+  //  - Name: NewName, then "_" + Suffix if given (plus "_m" when masked).
+  //  - MangledName: NewMangledName, or the part of NewName before the first
+  //    '_' when none is given, then "_" + MangledSuffix if given. The masked
+  //    "_m" suffix is NOT appended here.
+  BuiltinName = NewName.str();
+  Name = BuiltinName;
+  if (!Suffix.empty())
+    Name += "_" + Suffix.str();
+
+  MangledName = NewMangledName.empty() ? NewName.split("_").first.str()
+                                       : NewMangledName.str();
+  if (!MangledSuffix.empty())
+    MangledName += "_" + MangledSuffix.str();
+
+  if (IsMasked) {
+    BuiltinName += "_m";
+    Name += "_m";
+  }
+
+  // Predefined-macro requirements implied by the result/operand types.
+  for (const auto &Ty : OutInTypes) {
+    if (Ty->isFloatVector(16) || Ty->isFloat(16))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh;
+    if (Ty->isFloatVector(32))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32;
+    if (Ty->isFloatVector(64))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64;
+    if (Ty->isVector(64))
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64;
+  }
+
+  // Requirements implied by explicit features. This must run after the type
+  // scan: "FullMultiply" only adds V when VectorMaxELen64 was already set,
+  // since the full multiply instructions (mulh, mulhu, mulhsu, smul) for
+  // EEW=64 require V.
+  for (StringRef Feat : RequiredFeatures) {
+    if (Feat == "RV64") {
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64;
+    } else if (Feat == "FullMultiply" &&
+               (RISCVPredefinedMacros &
+                RISCVPredefinedMacro::VectorMaxELen64)) {
+      RISCVPredefinedMacros |= RISCVPredefinedMacro::V;
+    }
+  }
+
+  // OutInTypes holds the result type first, followed by the input types.
+  OutputType = OutInTypes.front();
+  InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
+
+  // NewIntrinsicTypes index the unmasked tail-agnostic operand list. A merge
+  // operand (masked-off or passthru) always sits first, so when one is present
+  // every non-negative index shifts by NF. Negative entries (e.g. -1, which
+  // stands for the result type) are left unchanged.
+  IntrinsicTypes = NewIntrinsicTypes;
+  const bool HasMergeOperand =
+      IsMasked ? HasMaskedOffOperand : hasPassthruOperand();
+  if (HasMergeOperand) {
+    for (auto &Idx : IntrinsicTypes) {
+      if (Idx >= 0)
+        Idx += NF;
+    }
+  }
+}
+
+std::string RVVIntrinsic::getBuiltinTypeStr() const {
+  // Builtin prototype string: the result type's builtin encoding followed by
+  // each input type's encoding, in operand order.
+  std::string Proto;
+  Proto += OutputType->getBuiltinStr();
+  for (const auto &InTy : InputTypes)
+    Proto += InTy->getBuiltinStr();
+  return Proto;
+}
+
+void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
+  // Emits generated C++ that selects the LLVM intrinsic (ID/NF) for this
+  // builtin. It then rearranges the operand vector `Ops` to the intrinsic's
+  // operand order and fills `IntrinsicTypes` with the overloaded types.
+  if (!getIRName().empty())
+    OS << "  ID = Intrinsic::riscv_" << getIRName() << ";\n";
+  if (NF >= 2)
+    OS << "  NF = " << getNF() << ";\n";
+  if (hasManualCodegen()) {
+    // Hand-written codegen replaces all of the operand handling below.
+    OS << ManualCodegen;
+    OS << "break;\n";
+    return;
+  }
+
+  if (isMasked()) {
+    if (hasVL()) {
+      // Move Ops[0] to just before the trailing VL operand.
+      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
+      // The policy operand reuses the VL operand's type.
+      if (hasPolicyOperand())
+        OS << "  Ops.push_back(ConstantInt::get(Ops.back()->getType(),"
+              " TAIL_UNDISTURBED));\n";
+    } else {
+      // Without VL, Ops[0] moves all the way to the end.
+      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
+    }
+  } else {
+    if (hasPolicyOperand())
+      OS << "  Ops.push_back(ConstantInt::get(Ops.back()->getType(), "
+            "TAIL_UNDISTURBED));\n";
+    else if (hasPassthruOperand()) {
+      // Append an undef passthru, then rotate it to the front.
+      OS << "  Ops.push_back(llvm::UndefValue::get(ResultType));\n";
+      OS << "  std::rotate(Ops.rbegin(), Ops.rbegin() + 1,  Ops.rend());\n";
+    }
+  }
+
+  // Index -1 denotes the result type; other indices select an operand type.
+  OS << "  IntrinsicTypes = {";
+  ListSeparator LS;
+  for (const auto &Idx : IntrinsicTypes) {
+    if (Idx == -1)
+      OS << LS << "ResultType";
+    else
+      OS << LS << "Ops[" << Idx << "]->getType()";
+  }
+
+  // VL could be i64 or i32, so it must be encoded in IntrinsicTypes too. VL is
+  // always the last operand. It is emitted through the same separator so that
+  // an empty IntrinsicTypes list does not produce "{, ...}".
+  if (hasVL())
+    OS << LS << "Ops.back()->getType()";
+  OS << "};\n";
+  OS << "  break;\n";
+}
+
+void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const {
+  // Declares the user-visible intrinsic as an alias of its builtin:
+  //   __attribute__((__clang_builtin_alias__(__builtin_rvv_<BuiltinName>)))
+  //   <OutputType> <Name>(<InputTypes, comma separated>);
+  OS << "__attribute__((__clang_builtin_alias__(__builtin_rvv_"
+     << getBuiltinName() << ")))\n"
+     << OutputType->getTypeStr() << " " << getName() << "(";
+  ListSeparator Sep;
+  for (const auto &ArgTy : InputTypes)
+    OS << Sep << ArgTy->getTypeStr();
+  OS << ");\n";
+}
+
+void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
+  // Declares the overloaded (mangled-name) intrinsic as an alias of the same
+  // builtin:
+  //   __attribute__((__clang_builtin_alias__(__builtin_rvv_<BuiltinName>)))
+  //   <OutputType> <MangledName>(<InputTypes, comma separated>);
+  OS << "__attribute__((__clang_builtin_alias__(__builtin_rvv_"
+     << getBuiltinName() << ")))\n";
+  OS << OutputType->getTypeStr() << " " << getMangledName() << "(";
+  ListSeparator Comma;
+  for (const auto &ArgTy : InputTypes)
+    OS << Comma << ArgTy->getTypeStr();
+  OS << ");\n";
+}
+
+} // end namespace RISCV
+} // end namespace llvm


        


More information about the llvm-commits mailing list