[flang-commits] [flang] f634280 - [flang][fir] Add FIR's vector type.

Eric Schweitz via flang-commits flang-commits at lists.llvm.org
Fri Feb 5 12:44:38 PST 2021


Author: Eric Schweitz
Date: 2021-02-05T12:44:19-08:00
New Revision: f6342806dbfcae5320013331958dae98855e75f6

URL: https://github.com/llvm/llvm-project/commit/f6342806dbfcae5320013331958dae98855e75f6
DIFF: https://github.com/llvm/llvm-project/commit/f6342806dbfcae5320013331958dae98855e75f6.diff

LOG: [flang][fir] Add FIR's vector type.

This patch adds support for `!fir.vector`, a rank-one, constant-length
data type.

https://github.com/flang-compiler/f18-llvm-project/pull/413

Differential Revision: https://reviews.llvm.org/D96162

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/include/flang/Optimizer/Dialect/FIRType.h
    flang/lib/Optimizer/Dialect/FIRDialect.cpp
    flang/lib/Optimizer/Dialect/FIRType.cpp
    flang/test/Fir/fir-types.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index a863e5e46dd3..8f3670b29d74 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -40,6 +40,8 @@ def fir_LogicalType : Type<CPred<"$_self.isa<fir::LogicalType>()">,
     "FIR logical type">;
 def fir_RealType : Type<CPred<"$_self.isa<fir::RealType>()">,
     "FIR real type">;
+def fir_VectorType : Type<CPred<"$_self.isa<fir::VectorType>()">,
+    "FIR vector type">;
 
 // Generalized FIR and standard dialect types representing intrinsic types
 def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
@@ -61,7 +63,7 @@ def fir_SequenceType : Type<CPred<"$_self.isa<fir::SequenceType>()">,
 // Composable types
 def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
     fir_SequenceType.predicate, fir_ComplexType.predicate,
-    IsTupleTypePred]>, "any composite">;
+    fir_VectorType.predicate, IsTupleTypePred]>, "any composite">;
 
 // Reference to an entity type
 def fir_ReferenceType : Type<CPred<"$_self.isa<fir::ReferenceType>()">,

diff  --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index a5f9dc5b7428..66a5e8bd7d18 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -55,6 +55,7 @@ struct RecordTypeStorage;
 struct ReferenceTypeStorage;
 struct SequenceTypeStorage;
 struct TypeDescTypeStorage;
+struct VectorTypeStorage;
 } // namespace detail
 
 // These isa_ routines follow the precedent of llvm::isa_or_null<>
@@ -363,14 +364,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
                                                           llvm::StringRef name);
 };
 
-mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
-
-void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
-
-/// Guarantee `type` is a scalar integral type (standard Integer, standard
-/// Index, or FIR Int). Aborts execution if condition is false.
-void verifyIntegralType(mlir::Type type);
-
 /// Is `t` a FIR Real or MLIR Float type?
 inline bool isa_real(mlir::Type t) {
   return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
@@ -382,6 +375,33 @@ inline bool isa_integer(mlir::Type t) {
          t.isa<fir::IntegerType>();
 }
 
+/// Replacement for the builtin vector type.
+/// The FIR vector type is always rank one. Its size is always a constant.
+/// A vector's element type must be real or integer (see isValidElementType).
+class VectorType : public mlir::Type::TypeBase<fir::VectorType, mlir::Type,
+                                               detail::VectorTypeStorage> {
+public:
+  using Base::Base;
+
+  static fir::VectorType get(uint64_t len, mlir::Type eleTy);
+  mlir::Type getEleTy() const; // element type (real or integer)
+  uint64_t getLen() const;     // constant vector length
+
+  static mlir::LogicalResult
+  verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
+  static bool isValidElementType(mlir::Type t) {
+    return isa_real(t) || isa_integer(t);
+  }
+};
+
+mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
+
+void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
+
+/// Guarantee `type` is a scalar integral type (standard Integer, standard
+/// Index, or FIR Int). Aborts execution if condition is false.
+void verifyIntegralType(mlir::Type type);
+
 /// Is `t` a FIR or MLIR Complex type?
 inline bool isa_complex(mlir::Type t) {
   return t.isa<fir::ComplexType>() || t.isa<mlir::ComplexType>();

diff  --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
index 477bb1e65ccc..f174c899795a 100644
--- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
@@ -18,7 +18,7 @@ fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
   addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
            FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
            PointerType, RealType, RecordType, ReferenceType, SequenceType,
-           TypeDescType>();
+           TypeDescType, fir::VectorType>();
   addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
                 PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
   addOperations<

diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index cec198311016..c52879fab6c2 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -142,6 +142,19 @@ TypeDescType parseTypeDesc(mlir::DialectAsmParser &parser, mlir::Location loc) {
   return parseTypeSingleton<TypeDescType>(parser, loc);
 }
 
+// `vector` `<` len `:` type `>`, with len >= 0 and type real or integer
+fir::VectorType parseVector(mlir::DialectAsmParser &parser, mlir::Location) {
+  int64_t len = 0;
+  mlir::Type eleTy;
+  if (parser.parseLess() || parser.parseInteger(len) || len < 0 ||
+      parser.parseColon() || parser.parseType(eleTy) ||
+      !fir::VectorType::isValidElementType(eleTy) || parser.parseGreater()) {
+    parser.emitError(parser.getNameLoc(), "invalid vector type");
+    return {};
+  }
+  return fir::VectorType::get(len, eleTy);
+}
+
 // `void`
 mlir::Type parseVoid(mlir::DialectAsmParser &parser) {
   return parser.getBuilder().getNoneType();
@@ -346,6 +359,8 @@ mlir::Type fir::parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser) {
     return parseDerived(parser, loc);
   if (typeNameLit == "void")
     return parseVoid(parser);
+  if (typeNameLit == "vector")
+    return parseVector(parser, loc);
 
   parser.emitError(parser.getNameLoc(), "unknown FIR type " + typeNameLit);
   return {};
@@ -790,6 +805,39 @@ struct TypeDescTypeStorage : public mlir::TypeStorage {
   explicit TypeDescTypeStorage(mlir::Type ofTy) : ofTy{ofTy} {}
 };
 
+/// Uniqued storage for fir::VectorType, keyed on (length, element type).
+struct VectorTypeStorage : public mlir::TypeStorage {
+  using KeyTy = std::tuple<uint64_t, mlir::Type>;
+
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<uint64_t>(key),
+                              std::get<mlir::Type>(key));
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy{getLen(), getEleTy()};
+  }
+
+  static VectorTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    auto *storage = allocator.allocate<VectorTypeStorage>();
+    return new (storage)
+        VectorTypeStorage{std::get<uint64_t>(key), std::get<mlir::Type>(key)};
+  }
+
+  uint64_t getLen() const { return len; }
+  mlir::Type getEleTy() const { return eleTy; }
+
+protected:
+  uint64_t len;     // vector length (first key component)
+  mlir::Type eleTy; // element type (second key component)
+
+private:
+  VectorTypeStorage() = delete;
+  explicit VectorTypeStorage(uint64_t len, mlir::Type eleTy)
+      : len{len}, eleTy{eleTy} {}
+};
+
 } // namespace detail
 
 template <typename A, typename B>
@@ -1069,12 +1117,34 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
       eleTy.isa<BoxProcType>() || eleTy.isa<FieldType>() ||
       eleTy.isa<LenType>() || eleTy.isa<HeapType>() ||
       eleTy.isa<PointerType>() || eleTy.isa<ReferenceType>() ||
-      eleTy.isa<TypeDescType>() || eleTy.isa<SequenceType>())
+      eleTy.isa<TypeDescType>() || eleTy.isa<fir::VectorType>() ||
+      eleTy.isa<SequenceType>())
     return mlir::emitError(loc, "cannot build an array of this element type: ")
            << eleTy << '\n';
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// Vector type
+//===----------------------------------------------------------------------===//
+
+fir::VectorType fir::VectorType::get(uint64_t len, mlir::Type eleTy) {
+  return Base::get(eleTy.getContext(), len, eleTy);
+}
+
+mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); }
+
+uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); }
+
+/// Reject element types that VectorType::isValidElementType does not accept.
+mlir::LogicalResult fir::VectorType::verifyConstructionInvariants(
+    mlir::Location loc, uint64_t /*len*/, mlir::Type eleTy) {
+  if (!isValidElementType(eleTy))
+    return mlir::emitError(loc, "cannot build a vector of type ")
+           << eleTy << '\n';
+  return mlir::success();
+}
+
 // compare if two shapes are equivalent
 bool fir::operator==(const SequenceType::Shape &sh_1,
                      const SequenceType::Shape &sh_2) {
@@ -1302,4 +1372,10 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
     os << '>';
     return;
   }
+  if (auto type = ty.dyn_cast<fir::VectorType>()) {
+    os << "vector<" << type.getLen() << ':';
+    p.printType(type.getEleTy());
+    os << '>';
+    return;
+  }
 }

diff  --git a/flang/test/Fir/fir-types.fir b/flang/test/Fir/fir-types.fir
index 789780809772..d7f2c2ae7fa4 100644
--- a/flang/test/Fir/fir-types.fir
+++ b/flang/test/Fir/fir-types.fir
@@ -71,3 +71,7 @@ func private @box5() -> !fir.box<!fir.type<derived3{f:f32}>>
 // CHECK-LABEL: func private @oth3() -> !fir.tdesc<!fir.type<derived7{f1:f32,f2:f32}>>
 func private @oth2() -> !fir.field
 func private @oth3() -> !fir.tdesc<!fir.type<derived7{f1:f32,f2:f32}>>
+
+// FIR vector
+// CHECK-LABEL: func private @vecty(i1) -> !fir.vector<10:i32>
+func private @vecty(i1) -> !fir.vector<10:i32>


        


More information about the flang-commits mailing list