[Mlir-commits] [mlir] 6144042 - [mlir][llvm] Move LLVMFixed/ScalableVectorType to TypeDef

Jeff Niu llvmlistbot at llvm.org
Fri Oct 21 15:13:25 PDT 2022


Author: Jeff Niu
Date: 2022-10-21T15:13:11-07:00
New Revision: 6144042c27a9a0b8ea90e49d82e800cde02a4e88

URL: https://github.com/llvm/llvm-project/commit/6144042c27a9a0b8ea90e49d82e800cde02a4e88
DIFF: https://github.com/llvm/llvm-project/commit/6144042c27a9a0b8ea90e49d82e800cde02a4e88.diff

LOG: [mlir][llvm] Move LLVMFixed/ScalableVectorType to TypeDef

This keeps the current parser, however, since the mnemonic `vec` is
overloaded for both of these types.

Depends on D136499

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D136505

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 08bcaaeaa14e..048546a04b2c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -205,93 +205,6 @@ class LLVMStructType
                                    ArrayRef<Type> replTypes) const;
 };
 
-//===----------------------------------------------------------------------===//
-// LLVMFixedVectorType.
-//===----------------------------------------------------------------------===//
-
-/// LLVM dialect fixed vector type, represents a sequence of elements of known
-/// length that can be processed as one.
-class LLVMFixedVectorType
-    : public Type::TypeBase<LLVMFixedVectorType, Type,
-                            detail::LLVMTypeAndSizeStorage,
-                            SubElementTypeInterface::Trait> {
-public:
-  /// Inherit base constructor.
-  using Base::Base;
-  using Base::getChecked;
-
-  /// Gets or creates a fixed vector type containing `numElements` of
-  /// `elementType` in the same context as `elementType`.
-  static LLVMFixedVectorType get(Type elementType, unsigned numElements);
-  static LLVMFixedVectorType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
-             unsigned numElements);
-
-  /// Checks if the given type can be used in a vector type. This type supports
-  /// only a subset of LLVM dialect types that don't have a built-in
-  /// counter-part, e.g., pointers.
-  static bool isValidElementType(Type type);
-
-  /// Returns the element type of the vector.
-  Type getElementType() const;
-
-  /// Returns the number of elements in the fixed vector.
-  unsigned getNumElements() const;
-
-  /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type elementType, unsigned numElements);
-
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const;
-  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const;
-};
-
-//===----------------------------------------------------------------------===//
-// LLVMScalableVectorType.
-//===----------------------------------------------------------------------===//
-
-/// LLVM dialect scalable vector type, represents a sequence of elements of
-/// unknown length that is known to be divisible by some constant. These
-/// elements can be processed as one in SIMD context.
-class LLVMScalableVectorType
-    : public Type::TypeBase<LLVMScalableVectorType, Type,
-                            detail::LLVMTypeAndSizeStorage,
-                            SubElementTypeInterface::Trait> {
-public:
-  /// Inherit base constructor.
-  using Base::Base;
-  using Base::getChecked;
-
-  /// Gets or creates a scalable vector type containing a non-zero multiple of
-  /// `minNumElements` of `elementType` in the same context as `elementType`.
-  static LLVMScalableVectorType get(Type elementType, unsigned minNumElements);
-  static LLVMScalableVectorType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
-             unsigned minNumElements);
-
-  /// Checks if the given type can be used in a vector type.
-  static bool isValidElementType(Type type);
-
-  /// Returns the element type of the vector.
-  Type getElementType() const;
-
-  /// Returns the scaling factor of the number of elements in the vector. The
-  /// vector contains at least the resulting number of elements, or any non-zero
-  /// multiple of this number.
-  unsigned getMinNumElements() const;
-
-  /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type elementType, unsigned minNumElements);
-
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const;
-  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const;
-};
-
 //===----------------------------------------------------------------------===//
 // Printing and parsing.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 960953866469..9d0203d90ed4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -167,4 +167,66 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMFixedVectorType
+//===----------------------------------------------------------------------===//
+
+def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+  let summary = "LLVM fixed vector type";
+  let description = [{
+    LLVM dialect fixed vector type, represents a sequence of elements of
+    known length that can be processed as one. Only LLVM dialect element
+    types without a built-in vector counterpart (e.g., pointers) are allowed.
+  }];
+
+  let parameters = (ins "Type":$elementType, "unsigned":$numElements);
+  let assemblyFormat = [{
+    `<` $numElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
+  }];
+
+  let genVerifyDecl = 1;
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                        "unsigned":$numElements)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Checks if the given type can be used as a fixed vector element type.
+    static bool isValidElementType(Type type);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// LLVMScalableVectorType
+//===----------------------------------------------------------------------===//
+
+def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+  let summary = "LLVM scalable vector type";
+  let description = [{
+    LLVM dialect scalable vector type, represents a sequence of elements of
+    unknown length that is known to be divisible by some constant. These
+    elements can be processed as one in SIMD context.
+  }];
+
+  let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
+  let assemblyFormat = [{
+    `<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
+  }];
+
+  let genVerifyDecl = 1;
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                        "unsigned":$minNumElements)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Checks if the given type can be used as a scalable vector element type.
+    static bool isValidElementType(Type type);
+  }];
+}
+
 #endif // LLVMTYPES_TD

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index de87d8e34e79..3ec212add15a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2573,8 +2573,6 @@ void LLVMDialect::initialize() {
            LLVMTokenType,
            LLVMLabelType,
            LLVMMetadataType,
-           LLVMFixedVectorType,
-           LLVMScalableVectorType,
            LLVMStructType>();
   // clang-format on
   registerTypes();

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 4a640221d999..fefaf6e634b8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -99,14 +99,6 @@ static void printStructType(AsmPrinter &printer, LLVMStructType type) {
   printer << '>';
 }
 
-/// Prints a type containing a fixed number of elements.
-template <typename TypeTy>
-static void printVectorType(AsmPrinter &printer, TypeTy type) {
-  printer << '<' << type.getNumElements() << " x ";
-  dispatchPrint(printer, type.getElementType());
-  printer << '>';
-}
-
 /// Prints the given LLVM dialect type recursively. This leverages closedness of
 /// the LLVM dialect type system to avoid printing the dialect prefix
 /// repeatedly. For recursive structures, only prints the name of the structure
@@ -124,26 +116,13 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
 
   printer << getTypeKeyword(type);
 
-  if (auto ptrType = type.dyn_cast<LLVMPointerType>())
-    return ptrType.print(printer);
-
-  if (auto arrayType = type.dyn_cast<LLVMArrayType>())
-    return arrayType.print(printer);
-  if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
-    return printVectorType(printer, vectorType);
-
-  if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
-    printer << "<? x " << vectorType.getMinNumElements() << " x ";
-    dispatchPrint(printer, vectorType.getElementType());
-    printer << '>';
-    return;
-  }
-
-  if (auto structType = type.dyn_cast<LLVMStructType>())
-    return printStructType(printer, structType);
-
-  if (auto funcType = type.dyn_cast<LLVMFunctionType>())
-    return funcType.print(printer);
+  llvm::TypeSwitch<Type>(type)
+      .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
+            LLVMScalableVectorType, LLVMFunctionType>(
+          [&](auto type) { type.print(printer); })
+      .Case([&](LLVMStructType structType) {
+        printStructType(printer, structType);
+      });
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index f37247331689..9187814a2b48 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -735,14 +735,6 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                           numElements);
 }
 
-Type LLVMFixedVectorType::getElementType() const {
-  return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
-}
-
-unsigned LLVMFixedVectorType::getNumElements() const {
-  return getImpl()->numElements;
-}
-
 bool LLVMFixedVectorType::isValidElementType(Type type) {
   return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
 }
@@ -783,14 +775,6 @@ LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                           minNumElements);
 }
 
-Type LLVMScalableVectorType::getElementType() const {
-  return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
-}
-
-unsigned LLVMScalableVectorType::getMinNumElements() const {
-  return getImpl()->numElements;
-}
-
 bool LLVMScalableVectorType::isValidElementType(Type type) {
   if (auto intType = type.dyn_cast<IntegerType>())
     return intType.isSignless();


        


More information about the Mlir-commits mailing list